From 7d5e6392c7835883a857071a922c6ff27d45fe92 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 13:52:06 +0100 Subject: [PATCH 001/176] torch default for Flip Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 8 +++++--- monai/transforms/transform.py | 26 +++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 96d443c193..abd6eb0a6e 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -310,14 +310,16 @@ class Flip(Transform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ + img, input_is_numpy = self.pre_conv_data(img) - result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) - return result.astype(img.dtype) + result = torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)).to(img.dtype) # type: ignore + + return self.post_convert_data(result, input_is_numpy) class Resize(Transform): diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e5715ee702..f6b5fde71b 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -197,6 +197,30 @@ def __call__(self, data: Any): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + def pre_conv_data( + self, data: Union[torch.Tensor, np.ndarray], requires_numpy: bool = False + ) -> Union[Union[torch.Tensor, np.ndarray], bool]: + """If input is in numpy, convert to torch. Also return the original state so that after the transform, + the data can be reverted to its original type. 
+ """ + input_is_numpy = isinstance(data, np.ndarray) + if input_is_numpy and not requires_numpy: + data = torch.Tensor(data) + if requires_numpy and not input_is_numpy: + data = data.detach().cpu().numpy() # type: ignore + return data, input_is_numpy + + def post_convert_data( + self, data: Union[torch.Tensor, np.ndarray], to_numpy: bool + ) -> Union[torch.Tensor, np.ndarray]: + """Convert back to original type.""" + is_numpy = isinstance(data, np.ndarray) + if is_numpy and not to_numpy: + data = torch.Tensor(data) + if to_numpy and not is_numpy: + data = data.detach().cpu().numpy() + return data + class RandomizableTransform(Randomizable, Transform): """ From b01d6bc683dfc177ee7e2e11cfe92375ddba92f6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 13:55:10 +0100 Subject: [PATCH 002/176] correct doc Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index f6b5fde71b..aac60ca19a 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -200,7 +200,7 @@ def __call__(self, data: Any): def pre_conv_data( self, data: Union[torch.Tensor, np.ndarray], requires_numpy: bool = False ) -> Union[Union[torch.Tensor, np.ndarray], bool]: - """If input is in numpy, convert to torch. Also return the original state so that after the transform, + """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. """ input_is_numpy = isinstance(data, np.ndarray) From 218a93c07373cdb142f016497d53ec11db1ba075 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 14:16:26 +0100 Subject: [PATCH 003/176] return type Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index aac60ca19a..6093719bf4 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -199,7 +199,7 @@ def __call__(self, data: Any): def pre_conv_data( self, data: Union[torch.Tensor, np.ndarray], requires_numpy: bool = False - ) -> Union[Union[torch.Tensor, np.ndarray], bool]: + ) -> Tuple[Union[torch.Tensor, np.ndarray], bool]: """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. 
""" From 7911ee505e4ca2f1e06ba10741904129db03fb94 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 15:15:24 +0100 Subject: [PATCH 004/176] typing Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/inverse.py | 5 ++--- monai/transforms/spatial/array.py | 9 +++++---- monai/transforms/spatial/dictionary.py | 25 ++++++++----------------- monai/transforms/transform.py | 22 ++++++++++++++-------- monai/utils/__init__.py | 1 + monai/utils/enums.py | 12 ++++++++++++ 6 files changed, 42 insertions(+), 32 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 3baef91717..b62bfcab3b 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,9 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Optional, Tuple +from typing import Hashable, Optional, Tuple -import numpy as np import torch from monai.transforms.transform import RandomizableTransform, Transform @@ -110,7 +109,7 @@ def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index abd6eb0a6e..de089e8cbc 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -45,6 +45,7 @@ issequenceiterable, optional_import, ) +from monai.utils.enums import TransformTypes nib, _ = optional_import("nibabel") @@ -310,7 +311,7 @@ class Flip(Transform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -777,7 +778,7 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -803,11 +804,11 @@ def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: TransformTypes.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4a4f039d28..de56d6717b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -53,7 +53,7 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import InverseKeys +from monai.utils.enums import InverseKeys, TransformTypes from monai.utils.module import optional_import nib, _ = 
optional_import("nibabel") @@ -1078,20 +1078,17 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1127,7 +1124,7 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: self.randomize(None) d = dict(data) for key in self.key_iterator(d): @@ -1136,15 +1133,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1171,11 +1165,11 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: TransformTypes.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: self.randomize(data=data[self.keys[0]]) flipper = Flip(spatial_axis=self._axis) @@ -1186,16 +1180,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, extra_info={"axis": self._axis}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 6093719bf4..8d65365f44 100644 --- 
a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np import torch @@ -22,8 +22,16 @@ from monai import transforms from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple +from monai.utils.enums import TransformTypes -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "Randomizable", + "RandomizableTransform", + "Transform", + "MapTransform", +] def apply_transform(transform: Callable, data, map_items: bool = True): @@ -198,8 +206,8 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def pre_conv_data( - self, data: Union[torch.Tensor, np.ndarray], requires_numpy: bool = False - ) -> Tuple[Union[torch.Tensor, np.ndarray], bool]: + self, data: TransformTypes.Images, requires_numpy: bool = False + ) -> Tuple[TransformTypes.Images, bool]: """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. """ @@ -210,15 +218,13 @@ def pre_conv_data( data = data.detach().cpu().numpy() # type: ignore return data, input_is_numpy - def post_convert_data( - self, data: Union[torch.Tensor, np.ndarray], to_numpy: bool - ) -> Union[torch.Tensor, np.ndarray]: + def post_convert_data(self, data: TransformTypes.Images, to_numpy: bool) -> TransformTypes.Images: """Convert back to original type.""" is_numpy = isinstance(data, np.ndarray) if is_numpy and not to_numpy: data = torch.Tensor(data) if to_numpy and not is_numpy: - data = data.detach().cpu().numpy() + data = data.detach().cpu().numpy() # type: ignore return data diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index d622ce96ae..539a152a00 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -30,6 +30,7 @@ NumpyPadMode, PytorchPadMode, SkipMode, + TransformTypes, UpsampleMode, Weight, ) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 1b9e4f615d..ad4c8d3e1c 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -10,6 +10,10 @@ # limitations under the License. 
from enum import Enum +from typing import Dict, Hashable, Union + +import numpy as np +import torch __all__ = [ "NumpyPadMode", @@ -31,6 +35,7 @@ "InverseKeys", "CommonKeys", "ForwardMode", + "TransformTypes", ] @@ -260,3 +265,10 @@ class CommonKeys: LABEL = "label" PRED = "pred" LOSS = "loss" + + +class TransformTypes: + """Common transform types.""" + + Images = Union[torch.Tensor, np.ndarray] + ImageDict = Dict[Hashable, Images] From 1d38104eb1a5e704ddd0f5249bba132f3b5b42e4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 15:30:04 +0100 Subject: [PATCH 005/176] use types instead of bools Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 8d65365f44..3802a60fbd 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -206,24 +206,28 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") def pre_conv_data( - self, data: TransformTypes.Images, requires_numpy: bool = False - ) -> Tuple[TransformTypes.Images, bool]: + self, data: TransformTypes.Images, required_type: type = torch.Tensor + ) -> Tuple[TransformTypes.Images, type]: """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. """ - input_is_numpy = isinstance(data, np.ndarray) - if input_is_numpy and not requires_numpy: + orig_type = type(data) + assert orig_type in (torch.Tensor, np.ndarray) + + if orig_type is np.ndarray and required_type is torch.Tensor: data = torch.Tensor(data) - if requires_numpy and not input_is_numpy: + elif orig_type is torch.Tensor and required_type is np.ndarray: data = data.detach().cpu().numpy() # type: ignore - return data, input_is_numpy + return data, orig_type - def post_convert_data(self, data: TransformTypes.Images, to_numpy: bool) -> TransformTypes.Images: + def post_convert_data(self, data: TransformTypes.Images, output_type: type) -> TransformTypes.Images: """Convert back to original type.""" - is_numpy = isinstance(data, np.ndarray) - if is_numpy and not to_numpy: + current_type = type(data) + assert current_type in (torch.Tensor, np.ndarray) + + if current_type is np.ndarray and output_type is torch.Tensor: data = torch.Tensor(data) - if to_numpy and not is_numpy: + elif current_type is torch.Tensor and output_type is np.ndarray: data = data.detach().cpu().numpy() # type: ignore return data From 6ade9c3db2234c14b1f1be64bb9f832b705df15a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 15:31:14 +0100 Subject: [PATCH 006/176] update var name Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index de089e8cbc..e20fb59b4b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -316,11 +316,11 @@ def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - img, input_is_numpy = self.pre_conv_data(img) + img, orig_type = self.pre_conv_data(img) result = 
torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)).to(img.dtype) # type: ignore - return self.post_convert_data(result, input_is_numpy) + return self.post_convert_data(result, orig_type) class Resize(Transform): From 427ee2f557f892b87a0ffbf29e7504826bd0b7c7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 27 May 2021 16:25:40 +0100 Subject: [PATCH 007/176] abstract common functionality Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 17 +++-------------- monai/utils/__init__.py | 1 + monai/utils/misc.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3802a60fbd..8c6c668895 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -23,6 +23,7 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import TransformTypes +from monai.utils.misc import convert_data_type __all__ = [ "ThreadUnsafe", @@ -211,24 +212,12 @@ def pre_conv_data( """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. """ - orig_type = type(data) - assert orig_type in (torch.Tensor, np.ndarray) - - if orig_type is np.ndarray and required_type is torch.Tensor: - data = torch.Tensor(data) - elif orig_type is torch.Tensor and required_type is np.ndarray: - data = data.detach().cpu().numpy() # type: ignore + data, orig_type = convert_data_type(data, required_type) return data, orig_type def post_convert_data(self, data: TransformTypes.Images, output_type: type) -> TransformTypes.Images: """Convert back to original type.""" - current_type = type(data) - assert current_type in (torch.Tensor, np.ndarray) - - if current_type is np.ndarray and output_type is torch.Tensor: - data = torch.Tensor(data) - elif current_type is torch.Tensor and output_type is np.ndarray: - data = data.detach().cpu().numpy() # type: ignore + data, _ = convert_data_type(data, output_type) return data diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 539a152a00..00ef6eacef 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -38,6 +38,7 @@ from .misc import ( MAX_SEED, ImageMetaKey, + convert_data_type, copy_to_device, dtype_numpy_to_torch, dtype_torch_to_numpy, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bd8e46d8b5..8007728629 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,6 +22,8 @@ import numpy as np import torch +from monai.utils.enums import TransformTypes + __all__ = [ "zip_with", "star_zip_with", @@ -42,6 +44,7 @@ "MAX_SEED", "copy_to_device", "ImageMetaKey", + "convert_data_type", ] _seed = None @@ -359,3 +362,15 @@ class ImageMetaKey: FILENAME_OR_OBJ = "filename_or_obj" PATCH_INDEX = "patch_index" + + +def convert_data_type(data, output_type: type) -> Tuple[TransformTypes.Images, type]: + """Convert to torch.Tensor/np.ndarray.""" + orig_type = type(data) + assert orig_type in (torch.Tensor, np.ndarray) + if orig_type is np.ndarray and output_type is torch.Tensor: + data = torch.Tensor(data) + elif orig_type is torch.Tensor and output_type is np.ndarray: + data = data.detach().cpu().numpy() + + return data, orig_type From 51055080a233d2297de1c9c431f8c4d1a699cdab Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 28 May 2021 16:25:04 
+0100 Subject: [PATCH 008/176] torch default Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/pathology/utils.py | 6 +- monai/config/__init__.py | 2 +- monai/config/type_definitions.py | 9 +-- monai/data/csv_saver.py | 7 ++- monai/data/nifti_saver.py | 5 +- monai/data/png_saver.py | 5 +- monai/handlers/utils.py | 5 +- monai/losses/dice.py | 3 +- monai/metrics/froc.py | 16 ++--- monai/metrics/hausdorff_distance.py | 5 +- monai/metrics/surface_distance.py | 5 +- monai/metrics/utils.py | 9 +-- monai/transforms/__init__.py | 11 +++- monai/transforms/croppad/array.py | 4 +- monai/transforms/croppad/batch.py | 6 +- monai/transforms/croppad/dictionary.py | 52 ++++++++--------- monai/transforms/intensity/array.py | 15 ++--- monai/transforms/intensity/dictionary.py | 55 ++++++++---------- monai/transforms/io/array.py | 11 +++- monai/transforms/io/dictionary.py | 4 +- monai/transforms/post/array.py | 3 +- monai/transforms/post/dictionary.py | 8 +-- monai/transforms/spatial/array.py | 48 +++++++-------- monai/transforms/spatial/dictionary.py | 74 ++++++++++-------------- monai/transforms/transform.py | 41 ++++++++++--- monai/transforms/utility/array.py | 56 ++++++++++-------- monai/transforms/utility/dictionary.py | 56 ++++++++---------- monai/transforms/utils.py | 3 +- monai/utils/__init__.py | 3 +- monai/utils/enums.py | 11 ++-- monai/utils/misc.py | 15 ----- monai/visualize/img2tensorboard.py | 15 ++--- tests/test_load_spacing_orientation.py | 8 ++- tests/test_remove_repeated_channel.py | 4 +- 34 files changed, 302 insertions(+), 278 deletions(-) diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index 0d1f530bff..9e51bcf6df 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List import numpy as np -import torch from monai.transforms.post.array import ProbNMS from monai.utils import optional_import +from monai.utils.enums import DataObjects measure, _ = optional_import("skimage.measure") ndimage, _ = optional_import("scipy.ndimage") @@ -67,7 +67,7 @@ class PathologyProbNMS(ProbNMS): def __call__( self, - probs_map: Union[np.ndarray, torch.Tensor], + probs_map: DataObjects.Images, resolution_level: int = 0, ): """ diff --git a/monai/config/__init__.py b/monai/config/__init__.py index f1c7707d1f..fcf8f067ee 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -18,4 +18,4 @@ print_gpu_info, print_system_info, ) -from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor +from .type_definitions import DtypeLike, IndexSelection, KeysCollection diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index daa9b10052..56a919df98 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -9,12 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Collection, Hashable, Iterable, TypeVar, Union +from typing import Collection, Hashable, Iterable, Union import numpy as np -import torch -__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor"] +__all__ = ["KeysCollection", "IndexSelection", "DtypeLike"] """Commonly used concepts This module provides naming and type specifications for commonly used concepts @@ -63,7 +62,3 @@ """Type of datatypes adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py """ - -# Generic type which can represent either a numpy.ndarray or a torch.Tensor -# Unlike Union can create a dependence between parameter(s) / return(s) -NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor) diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 3701c094cd..a2880795fd 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -12,12 +12,13 @@ import os import warnings from collections import OrderedDict -from typing import Dict, Optional, Union +from typing import Dict, Optional import numpy as np import torch from monai.utils import ImageMetaKey as Key +from monai.utils.enums import DataObjects class CSVSaver: @@ -75,7 +76,7 @@ def finalize(self) -> None: # clear cache content after writing self.reset_cache() - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save data into the cache dictionary. The metadata should have the following key: - ``'filename_or_obj'`` -- save the data corresponding to file name or object. If meta_data is None, use the default index from 0 to save data instead. @@ -93,7 +94,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] raise AssertionError self._cache_dict[save_key] = data.astype(np.float32) - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save a batch of data into the cache dictionary. Args: diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index 8c207b0a2c..274b391e48 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -19,6 +19,7 @@ from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils.enums import DataObjects class NiftiSaver: @@ -96,7 +97,7 @@ def __init__( self.data_root_dir = data_root_dir self.print_log = print_log - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save data into a Nifti file. The meta_data could optionally have the following keys: @@ -160,7 +161,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if self.print_log: print(f"file written: {path}.") - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save a batch of data into Nifti format files. 
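[PATCH 008] also replaces the `TransformTypes` aliases with `DataObjects` throughout, but the `monai/utils/enums.py` hunk that defines `DataObjects` falls outside the hunks shown here. Judging only from the call sites (`DataObjects.Images`, `DataObjects.Dict`, `DataObjects.Mapping`) and by analogy with the `TransformTypes` container from [PATCH 004], its presumed shape is the sketch below; the member definitions are inferred, not taken from the patch.

    from typing import Dict, Hashable, Mapping, Union

    import numpy as np
    import torch

    class DataObjects:
        """Presumed alias container (members inferred from the usages above)."""

        Images = Union[torch.Tensor, np.ndarray]
        Dict = Dict[Hashable, Images]
        Mapping = Mapping[Hashable, Images]
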
diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 90c1acfc1f..307c7b10b7 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -18,6 +18,7 @@ from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode +from monai.utils.enums import DataObjects class PNGSaver: @@ -74,7 +75,7 @@ def __init__( self._data_index = 0 - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save(self, data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """ Save data into a png file. The meta_data could optionally have the following keys: @@ -128,7 +129,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if self.print_log: print(f"file written: {path}.") - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def save_batch(self, batch_data: DataObjects.Images, meta_data: Optional[Dict] = None) -> None: """Save a batch of data into png format files. Args: diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index af35eaa953..df8ec235c8 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -17,6 +17,7 @@ import torch from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import +from monai.utils.enums import DataObjects idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: @@ -123,8 +124,8 @@ def string_list_all_gather(strings: List[str]) -> List[str]: def write_metrics_reports( save_dir: str, images: Optional[Sequence[str]], - metrics: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], - metric_details: Optional[Dict[str, Union[torch.Tensor, np.ndarray]]], + metrics: Optional[Dict[str, DataObjects.Images]], + metric_details: Optional[Dict[str, DataObjects.Images]], summary_ops: Optional[Union[str, Sequence[str]]], deli: str = "\t", output_type: str = "csv", diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 52757aeb66..4b73f63343 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -21,6 +21,7 @@ from monai.losses.focal_loss import FocalLoss from monai.networks import one_hot from monai.utils import LossReduction, Weight +from monai.utils.enums import DataObjects class DiceLoss(_Loss): @@ -382,7 +383,7 @@ class GeneralizedWassersteinDiceLoss(_Loss): def __init__( self, - dist_matrix: Union[np.ndarray, torch.Tensor], + dist_matrix: DataObjects.Images, weighting_mode: str = "default", reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index faebbbf7a6..93ceee0efc 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -10,17 +10,19 @@ # limitations under the License. 
-from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import numpy as np import torch +from monai.utils.enums import DataObjects + def compute_fp_tp_probs( - probs: Union[np.ndarray, torch.Tensor], - y_coord: Union[np.ndarray, torch.Tensor], - x_coord: Union[np.ndarray, torch.Tensor], - evaluation_mask: Union[np.ndarray, torch.Tensor], + probs: DataObjects.Images, + y_coord: DataObjects.Images, + x_coord: DataObjects.Images, + evaluation_mask: DataObjects.Images, labels_to_exclude: Optional[List] = None, resolution_level: int = 0, ): @@ -77,8 +79,8 @@ def compute_fp_tp_probs( def compute_froc_curve_data( - fp_probs: Union[np.ndarray, torch.Tensor], - tp_probs: Union[np.ndarray, torch.Tensor], + fp_probs: DataObjects.Images, + tp_probs: DataObjects.Images, num_targets: int, num_images: int, ): diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 9617c0365a..80fbc3d641 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -17,6 +17,7 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from monai.utils.enums import DataObjects __all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"] @@ -98,8 +99,8 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): def compute_hausdorff_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, include_background: bool = False, distance_metric: str = "euclidean", percentile: Optional[float] = None, diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index d4b2a84572..bb7a39ec61 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -17,6 +17,7 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from monai.utils.enums import DataObjects class SurfaceDistanceMetric: @@ -88,8 +89,8 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): def compute_average_surface_distance( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, include_background: bool = False, symmetric: bool = False, distance_metric: str = "euclidean", diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index b7320918d8..0acd934ada 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -18,6 +18,7 @@ from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box from monai.utils import MetricReduction, optional_import +from monai.utils.enums import DataObjects binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") @@ -27,8 +28,8 @@ def ignore_background( - y_pred: Union[np.ndarray, torch.Tensor], - y: Union[np.ndarray, torch.Tensor], + y_pred: DataObjects.Images, + y: DataObjects.Images, ): """ This function is used to remove background (the first channel) for `y_pred` and `y`. 
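These dual-typed metric signatures lean on the `convert_data_type` helper introduced in [PATCH 007] (and re-exported from `monai.transforms` in the `__init__.py` hunk below). A minimal round-trip sketch, assuming it keeps the `(data, output_type) -> (data, orig_type)` signature shown in [PATCH 007]:

    import numpy as np
    import torch

    from monai.transforms import convert_data_type

    img = np.zeros((1, 16, 16), dtype=np.float32)
    img_t, orig_type = convert_data_type(img, torch.Tensor)  # to torch, remembers np.ndarray
    # ... operate on `img_t` with torch ops ...
    img_np, _ = convert_data_type(img_t, orig_type)          # restored to np.ndarray
    assert isinstance(img_np, np.ndarray)
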
@@ -105,8 +106,8 @@ def do_metric_reduction( def get_mask_edges( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], + seg_pred: DataObjects.Images, + seg_gt: DataObjects.Images, label_idx: int = 1, crop: bool = True, ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ae8b06647e..2055a85edd 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -293,7 +293,16 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + MapTransform, + NumpyTransform, + Randomizable, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, + convert_data_type, +) from .utility.array import ( AddChannel, AddExtremePointsChannel, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7557565b83..91443eaa91 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -18,7 +18,6 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np -import torch from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size @@ -32,6 +31,7 @@ weighted_patch_samples, ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils.enums import DataObjects __all__ = [ "SpatialPad", @@ -259,7 +259,7 @@ def __init__( # convert to slices self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + def __call__(self, img: DataObjects.Images): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
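With the widened annotation, `SpatialCrop` is expected to behave identically on either container, since the precomputed slices above index both the same way. A quick sketch, using the constructor keywords as they appear in the dictionary wrapper below:

    import numpy as np
    import torch

    from monai.transforms import SpatialCrop

    crop = SpatialCrop(roi_start=(1, 1), roi_end=(3, 3))
    print(crop(np.ones((1, 4, 4))).shape)   # (1, 2, 2), stays np.ndarray
    print(crop(torch.ones(1, 4, 4)).shape)  # torch.Size([1, 2, 2]), stays torch.Tensor
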
diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 3ecabc387b..336f48593b 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -14,7 +14,7 @@ """ from copy import deepcopy -from typing import Any, Dict, Hashable, Union +from typing import Any, Union import numpy as np import torch @@ -24,7 +24,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform from monai.transforms.utility.array import ToTensor -from monai.utils.enums import InverseKeys, Method, NumpyPadMode +from monai.utils.enums import DataObjects, InverseKeys, Method, NumpyPadMode __all__ = [ "PadListDataCollate", @@ -116,7 +116,7 @@ def __call__(self, batch: Any): return list_data_collate(batch) @staticmethod - def inverse(data: dict) -> Dict[Hashable, np.ndarray]: + def inverse(data: dict) -> DataObjects.Dict: if not isinstance(data, dict): raise RuntimeError("Inverse can only currently be applied on dictionaries.") diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4ec14e0a7d..2c36f23857 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -19,7 +19,7 @@ from enum import Enum from itertools import chain from math import ceil, floor -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -45,7 +45,7 @@ ) from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys __all__ = [ "NumpyPadModeSequence", @@ -130,14 +130,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = SpatialPad(spatial_size, method) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -198,14 +198,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = BorderPad(spatial_border=spatial_border) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -264,14 +264,14 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padder = DivisiblePad(k=k) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: 
DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m}) d[key] = self.padder(d[key], mode=m) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -326,14 +326,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -373,7 +373,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.cropper = CenterSpatialCrop(roi_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): orig_size = d[key].shape[1:] @@ -381,7 +381,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, orig_size=orig_size) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -421,7 +421,7 @@ def __init__( super().__init__(keys, allow_missing_keys=allow_missing_keys) self.roi_scale = roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) img_size = data[self.keys[0]].shape[1:] ndim = len(img_size) @@ -481,7 +481,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key if self._size is None: @@ -496,7 +496,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = cropper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -575,7 +575,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: img_size = data[self.keys[0]].shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] @@ -656,7 +656,7 @@ def set_random_state( def randomize(self, data: Optional[Any] = None) -> None: pass - def __call__(self, 
data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: ret = [] for i in range(self.num_samples): d = dict(data) @@ -731,7 +731,7 @@ def __init__( mode=mode, ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key]) d[self.start_coord_key] = box_start @@ -741,7 +741,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.cropper.crop_pad(d[key], box_start, box_end) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -825,12 +825,12 @@ def randomize(self, weight_map: np.ndarray) -> None: spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) self.randomize(d[self.w_key]) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) - results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] + results: List[DataObjects.Dict] = [{} for _ in range(self.num_samples)] for key in self.key_iterator(d): img = d[key] if img.shape[1:] != d[self.w_key].shape[1:]: @@ -958,7 +958,7 @@ def randomize( self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) label = d[self.label_key] image = d[self.image_key] if self.image_key else None @@ -970,7 +970,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n raise AssertionError if self.centers is None: raise AssertionError - results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] + results: List[DataObjects.Dict] = [{} for _ in range(self.num_samples)] for i, center in enumerate(self.centers): for key in self.key_iterator(d): @@ -1019,7 +1019,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, m in self.key_iterator(d, self.mode): orig_size = d[key].shape[1:] @@ -1034,7 +1034,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1094,7 +1094,7 @@ def __init__( self.bbox = BoundingRect(select_fn=select_fn) self.bbox_key_postfix = bbox_key_postfix - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: 
""" See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index cf4e33b229..6688846637 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -31,6 +31,7 @@ ensure_tuple_rep, ensure_tuple_size, ) +from monai.utils.enums import DataObjects __all__ = [ "RandGaussianNoise", @@ -81,7 +82,7 @@ def randomize(self, im_shape: Sequence[int]) -> None: super().randomize(None) self._noise = self.R.normal(self.mean, self.R.uniform(0, self.std), size=im_shape) - def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -136,7 +137,7 @@ def __init__( self._noise1 = None self._noise2 = None - def _add_noise(self, img: Union[torch.Tensor, np.ndarray], mean: float, std: float): + def _add_noise(self, img: DataObjects.Images, mean: float, std: float): im_shape = img.shape _std = self.R.uniform(0, std) if self.sample_std else std self._noise1 = self.R.normal(mean, _std, size=im_shape) @@ -146,7 +147,7 @@ def _add_noise(self, img: Union[torch.Tensor, np.ndarray], mean: float, std: flo dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype return np.sqrt((img + self._noise1.astype(dtype)) ** 2 + self._noise2.astype(dtype) ** 2) - def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -1174,7 +1175,7 @@ def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_te RandomizableTransform.__init__(self, prob=prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: # randomize application and possibly alpha self._randomize(None) @@ -1226,7 +1227,7 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: self.as_tensor_output = as_tensor_output self._device = torch.device("cpu") - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: n_dims = len(img.shape[1:]) # convert to ndarray to work with np.fft @@ -1243,7 +1244,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, img = self._inv_shift_fourier(k, n_dims) return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img - def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _shift_fourier(self, x: DataObjects.Images, n_dims: int) -> np.ndarray: """ Applies fourier transform and shifts its output. Only the spatial dimensions get transformed. @@ -1254,7 +1255,7 @@ def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np. out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) return out - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int) -> np.ndarray: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. 
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 07da1e651e..4199fe0c44 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -16,7 +16,7 @@ """ from collections.abc import Iterable -from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -40,6 +40,7 @@ ) from monai.transforms.transform import MapTransform, RandomizableTransform from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size +from monai.utils.enums import DataObjects __all__ = [ "RandGaussianNoised", @@ -148,7 +149,7 @@ def randomize(self, im_shape: Sequence[int]) -> None: for m in self.mean: self._noise.append(self.R.normal(m, self.R.uniform(0, self.std), size=im_shape)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) image_shape = d[self.keys[0]].shape # image shape from the first data key @@ -203,9 +204,7 @@ def __init__( RandomizableTransform.__init__(self, global_prob) self.rand_rician_noise = RandRicianNoise(prob, mean, std, channel_wise, relative, sample_std) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) super().randomize(None) if not self._do_transform: @@ -231,7 +230,7 @@ def __init__(self, keys: KeysCollection, offset: float, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.shifter = ShiftIntensity(offset) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.shifter(d[key]) @@ -275,7 +274,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -314,7 +313,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.shifter = StdShiftIntensity(factor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.shifter(d[key]) @@ -366,7 +365,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -406,7 +405,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensity(minv, maxv, factor) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = 
self.scaler(d[key]) @@ -451,7 +450,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -496,7 +495,7 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -537,7 +536,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.normalizer(d[key]) @@ -568,7 +567,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) @@ -603,7 +602,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -628,7 +627,7 @@ def __init__(self, keys: KeysCollection, gamma: float, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.adjuster = AdjustContrast(gamma) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.adjuster(d[key]) @@ -678,7 +677,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if self.gamma_value is None: @@ -721,7 +720,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.scaler(d[key]) @@ -757,7 +756,7 @@ def __init__( self.converter = MaskIntensity(mask_data) self.mask_key = mask_key if mask_data is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key]) @@ 
-790,7 +789,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSmooth(sigma, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -837,7 +836,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -880,7 +879,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1006,7 +1005,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) self.randomize() if not self._do_transform: @@ -1061,9 +1060,7 @@ def __init__( self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() self.as_tensor_output = as_tensor_output - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) self._randomize(None) @@ -1088,7 +1085,7 @@ def _randomize(self, _: Any) -> None: super().randomize(None) self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) - def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + def _to_numpy(self, d: DataObjects.Images) -> np.ndarray: if isinstance(d, torch.Tensor): d_numpy: np.ndarray = d.cpu().detach().numpy() return d_numpy @@ -1121,9 +1118,7 @@ def __init__( MapTransform.__init__(self, keys, allow_missing_keys) self.transform = GibbsNoise(alpha, as_tensor_output) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index eb8aff0656..2933eaa6e3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -26,6 +26,7 @@ from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, ensure_tuple, optional_import +from monai.utils.enums import DataObjects nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -70,6 +71,7 @@ def __init__( reader: Optional[Union[ImageReader, str]] = None, image_only: bool = False, dtype: DtypeLike = np.float32, + as_tensor: bool = True, *args, **kwargs, ) -> None: @@ -81,6 +83,7 @@ def __init__( "PILReader", "ITKReader", "NumpyReader". 
            image_only: if True, return only the image volume; otherwise, return the image data array and header dict.
            dtype: if not None, convert the loaded image to this data type.
+            as_tensor: whether to output the image as a `torch.Tensor`; defaults to `True`.
            args: additional parameters for reader if providing a reader name.
            kwargs: additional parameters for reader if providing a reader name.
@@ -108,6 +111,7 @@ def __init__(
         self.image_only = image_only
         self.dtype = dtype
+        self.as_tensor = as_tensor
 
     def register(self, reader: ImageReader) -> List[ImageReader]:
         """
@@ -152,9 +156,14 @@ def __call__(
             )
 
         img = reader.read(filename)
+        img_array: DataObjects.Images
         img_array, meta_data = reader.get_data(img)
         img_array = img_array.astype(self.dtype)
 
+        # convert to desired output type
+        if self.as_tensor:
+            img_array = torch.Tensor(img_array)
+
         if self.image_only:
             return img_array
         meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0]
@@ -274,7 +283,7 @@ def __init__(
 
         self.save_batch = save_batch
 
-    def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
+    def __call__(self, img: DataObjects.Images, meta_data: Optional[Dict] = None):
         """
         Args:
             img: target data content to save into file.
diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py
index 39649e3858..c3ecddb4e1 100644
--- a/monai/transforms/io/dictionary.py
+++ b/monai/transforms/io/dictionary.py
@@ -61,6 +61,7 @@ def __init__(
         meta_key_postfix: str = "meta_dict",
         overwriting: bool = False,
         image_only: bool = False,
+        as_tensor: bool = True,
         allow_missing_keys: bool = False,
         *args,
         **kwargs,
     ) -> None:
@@ -85,12 +86,13 @@ def __init__(
                 default is False, which will raise an exception if an existing key is encountered.
             image_only: if True, return a dictionary containing only the image volumes; otherwise, return a
                 dictionary containing the image data array and header dict per input key.
+            as_tensor: whether to output the image as a `torch.Tensor`; defaults to `True`.
             allow_missing_keys: don't raise exception if key is missing.
             args: additional parameters for reader if providing a reader name.
             kwargs: additional parameters for reader if providing a reader name.
         """
         super().__init__(keys, allow_missing_keys)
-        self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
+        self._loader = LoadImage(reader, image_only, dtype, as_tensor, *args, **kwargs)
         if not isinstance(meta_key_postfix, str):
             raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
         self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index dd4f7afd9d..0a7cb8c716 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -25,6 +25,7 @@
 from monai.transforms.transform import Transform
 from monai.transforms.utils import get_largest_connected_component_mask
 from monai.utils import ensure_tuple
+from monai.utils.enums import DataObjects
 
 __all__ = [
     "Activations",
@@ -485,7 +486,7 @@ def __init__(
 
     def __call__(
         self,
-        prob_map: Union[np.ndarray, torch.Tensor],
+        prob_map: DataObjects.Images,
     ):
         """
         prob_map: the input probability map; it must have shape (H[, W, ...]).
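The `as_tensor` flag above lets `LoadImage` (and the dict-based `LoadImaged`) choose the output backend at load time. A minimal usage sketch of the intended behaviour, assuming this patch series is applied; the file path "img.nii.gz" is hypothetical:

    import numpy as np
    import torch
    from monai.transforms import LoadImage

    # default: the loaded volume comes back as a torch.Tensor
    img = LoadImage(image_only=True)("img.nii.gz")
    assert isinstance(img, torch.Tensor)

    # opt out to keep the previous numpy behaviour
    img_np = LoadImage(image_only=True, as_tensor=False)("img.nii.gz")
    assert isinstance(img_np, np.ndarray)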
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 16e21a9881..883ac8c02f 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -17,7 +17,7 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union import numpy as np import torch @@ -41,7 +41,7 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys __all__ = [ "Activationsd", @@ -407,7 +407,7 @@ def __init__( box_size=box_size, ) - def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): + def __call__(self, data: DataObjects.Mapping): d = dict(data) for key in self.key_iterator(d): d[key] = self.prob_nms(d[key]) @@ -513,7 +513,7 @@ def __init__( self.post_func = ensure_tuple_rep(post_func, len(self.keys)) self._totensor = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for ( key, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e20fb59b4b..e6dc90d742 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -45,7 +45,7 @@ issequenceiterable, optional_import, ) -from monai.utils.enums import TransformTypes +from monai.utils.enums import DataObjects nib, _ = optional_import("nibabel") @@ -311,7 +311,7 @@ class Flip(Transform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -778,7 +778,7 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -804,11 +804,11 @@ def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: TransformTypes.Images) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: TransformTypes.Images) -> TransformTypes.Images: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -956,7 +956,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine: Optional[Union[np.ndarray, torch.Tensor]] = None, + affine: Optional[DataObjects.Images] = None, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -971,8 +971,8 @@ def __init__( def __call__( self, spatial_size: 
Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + grid: Optional[DataObjects.Images] = None, + ) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Args: spatial_size: output grid size. @@ -1063,7 +1063,7 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None + self.affine: Optional[DataObjects.Images] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1085,8 +1085,8 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, spatial_size: Optional[Sequence[int]] = None, - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + grid: Optional[DataObjects.Images] = None, + ) -> DataObjects.Images: """ Args: spatial_size: output grid size. @@ -1107,7 +1107,7 @@ def __call__( grid, self.affine = affine_grid(spatial_size, grid) return grid - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + def get_transformation_matrix(self) -> Optional[DataObjects.Images]: """Get the most recently applied transformation matrix""" return self.affine @@ -1191,11 +1191,11 @@ def __init__( def __call__( self, - img: Union[np.ndarray, torch.Tensor], - grid: Optional[Union[np.ndarray, torch.Tensor]] = None, + img: DataObjects.Images, + grid: Optional[DataObjects.Images] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W[, D]). @@ -1314,11 +1314,11 @@ def __init__( def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + ) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1465,11 +1465,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1590,11 +1590,11 @@ def randomize(self, spatial_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W), @@ -1717,11 +1717,11 @@ def randomize(self, grid_size: Sequence[int]) -> None: def __call__( self, - img: Union[np.ndarray, torch.Tensor], + img: DataObjects.Images, spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) 
-> DataObjects.Images: """ Args: img: shape must be (num_channels, H, W, D), @@ -1771,7 +1771,7 @@ def __init__( """ self.spatial_channels = spatial_channels - def __call__(self, img: Union[np.ndarray, torch.Tensor]): + def __call__(self, img: DataObjects.Images): """ Args: img: data to be transformed, assuming `img` is channel first. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index de56d6717b..d9bd4a1599 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -53,7 +53,7 @@ ensure_tuple_rep, fall_back_tuple, ) -from monai.utils.enums import InverseKeys, TransformTypes +from monai.utils.enums import DataObjects, InverseKeys from monai.utils.module import optional_import nib, _ = optional_import("nibabel") @@ -239,7 +239,7 @@ def __call__( meta_data["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -348,7 +348,7 @@ def __call__( d[meta_key]["affine"] = new_affine return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -388,14 +388,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -454,7 +454,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: self.randomize() d = dict(data) @@ -465,7 +465,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. 
self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -521,7 +521,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): self.push_transform( @@ -535,7 +535,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -616,9 +616,7 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] @@ -635,7 +633,7 @@ def __call__( ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -751,9 +749,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self.rand_affine.randomize() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) self.randomize() @@ -792,14 +788,14 @@ def __call__( return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. 
if not transform[InverseKeys.DO_TRANSFORM] and self.rand_affine.spatial_size is None: - out: Union[np.ndarray, torch.Tensor] = d[key] + out: DataObjects.Images = d[key] else: orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform @@ -915,9 +911,7 @@ def randomize(self, spatial_size: Sequence[int]) -> None: super().randomize(None) self.rand_2d_elastic.randomize(spatial_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -1035,9 +1029,7 @@ def randomize(self, grid_size: Sequence[int]) -> None: super().randomize(None) self.rand_3d_elastic.randomize(grid_size) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) @@ -1078,14 +1070,14 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.flipper(d[key]) return d - def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) @@ -1124,7 +1116,7 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: self.randomize(None) d = dict(data) for key in self.key_iterator(d): @@ -1133,7 +1125,7 @@ def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: self.push_transform(d, key) return d - def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1165,11 +1157,11 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: TransformTypes.Images) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def __call__(self, data: DataObjects.Dict) -> DataObjects.Dict: self.randomize(data=data[self.keys[0]]) flipper = Flip(spatial_axis=self._axis) @@ -1180,7 +1172,7 @@ def __call__(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: self.push_transform(d, key, extra_info={"axis": self._axis}) return d - def inverse(self, data: TransformTypes.ImageDict) -> TransformTypes.ImageDict: + def inverse(self, data: DataObjects.Dict) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ 
-1241,7 +1233,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype @@ -1268,7 +1260,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1375,7 +1367,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: self.randomize() d = dict(data) angle: Union[Sequence[float], float] = self.x if d[self.keys[0]].ndim == 3 else (self.x, self.y, self.z) @@ -1411,7 +1403,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) @@ -1484,7 +1476,7 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.zoomer = Zoom(zoom=zoom, keep_size=keep_size) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners @@ -1506,7 +1498,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1594,7 +1586,7 @@ def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: # match the spatial dim of first item self.randomize() d = dict(data) @@ -1629,7 +1621,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda ) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1675,9 +1667,7 @@ def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_ super().__init__(keys, allow_missing_keys) 
self.add_coordinate_channels = AddCoordinateChannels(spatial_channels) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): d[key] = self.add_coordinate_channels(d[key]) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 8c6c668895..c3c42fff13 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Generator, Hashable, Iterable, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -22,8 +22,7 @@ from monai import transforms from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -from monai.utils.enums import TransformTypes -from monai.utils.misc import convert_data_type +from monai.utils.enums import DataObjects __all__ = [ "ThreadUnsafe", @@ -32,9 +31,23 @@ "RandomizableTransform", "Transform", "MapTransform", + "convert_data_type", + "NumpyTransform", ] +def convert_data_type(data, output_type: type) -> Tuple[DataObjects.Images, type]: + """Convert to torch.Tensor/np.ndarray.""" + orig_type = type(data) + assert orig_type in (torch.Tensor, np.ndarray), f"Expected types {(torch.Tensor, np.ndarray)}, got {orig_type}" + if orig_type is np.ndarray and output_type is torch.Tensor: + data = torch.Tensor(data) + elif orig_type is torch.Tensor and output_type is np.ndarray: + data = data.detach().cpu().numpy() + + return data, orig_type + + def apply_transform(transform: Callable, data, map_items: bool = True): """ Transform `data` with `transform`. @@ -178,6 +191,9 @@ class Transform(ABC): :py:class:`monai.transforms.Compose` """ + array_class: Union[Type[torch.Tensor], Type[np.ndarray]] + array_class = torch.Tensor + @abstractmethod def __call__(self, data: Any): """ @@ -206,21 +222,28 @@ def __call__(self, data: Any): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def pre_conv_data( - self, data: TransformTypes.Images, required_type: type = torch.Tensor - ) -> Tuple[TransformTypes.Images, type]: + def pre_conv_data(self, data: DataObjects.Images) -> Tuple[DataObjects.Images, type]: """Convert to torch/numpy, as required. Also return the original state so that after the transform, the data can be reverted to its original type. """ - data, orig_type = convert_data_type(data, required_type) + data, orig_type = convert_data_type(data, self.array_class) return data, orig_type - def post_convert_data(self, data: TransformTypes.Images, output_type: type) -> TransformTypes.Images: + @staticmethod + def post_convert_data(data: DataObjects.Images, output_type: type) -> DataObjects.Images: """Convert back to original type.""" data, _ = convert_data_type(data, output_type) return data +class NumpyTransform(Transform): + """Most transforms use torch. Transforms that inherit from this class, however, use numpy under the hood. 
+ This means that if the input image is `torch.Tensor`, it will be converted to numpy and then reverted + at the end.""" + + array_class = np.ndarray + + class RandomizableTransform(Randomizable, Transform): """ An interface for handling random state locally, currently based on a class variable `R`, @@ -331,7 +354,7 @@ def __call__(self, data): def key_iterator( self, - data: Dict[Hashable, Any], + data: DataObjects.Dict, *extra_iterables: Optional[Iterable], ) -> Generator: """ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fdb0abe785..83cbab7a9c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -21,10 +21,11 @@ import numpy as np import torch -from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, Transform +from monai.config import DtypeLike +from monai.transforms.transform import Randomizable, Transform, convert_data_type from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import +from monai.utils.enums import DataObjects PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -66,7 +67,7 @@ class Identity(Transform): """ - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> np.ndarray: """ Apply the transform to `img`. """ @@ -94,11 +95,13 @@ def __init__(self, channel_dim: int = -1) -> None: raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, 0) + img, orig_type = self.pre_conv_data(img) + img = torch.moveaxis(img, self.channel_dim, 0) # type: ignore + return self.post_convert_data(img, orig_type) class AsChannelLast(Transform): @@ -121,11 +124,13 @@ def __init__(self, channel_dim: int = 0) -> None: raise AssertionError("invalid channel dimension.") self.channel_dim = channel_dim - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.moveaxis(img, self.channel_dim, -1) + img, orig_type = self.pre_conv_data(img) + img = torch.moveaxis(img, self.channel_dim, -1) # type: ignore + return self.post_convert_data(img, orig_type) class AddChannel(Transform): @@ -142,7 +147,7 @@ class AddChannel(Transform): transforms. """ - def __call__(self, img: NdarrayTensor): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -158,7 +163,7 @@ class EnsureChannelFirst(Transform): """ - def __call__(self, img: np.ndarray, meta_dict: Optional[Dict] = None): + def __call__(self, img: DataObjects.Images, meta_dict: Optional[Dict] = None) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -189,11 +194,13 @@ def __init__(self, repeats: int) -> None: raise AssertionError("repeats count must be greater than 0.") self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. 
""" - return np.repeat(img, self.repeats, 0) + img, orig_type = self.pre_conv_data(img) + img = torch.repeat_interleave(img, self.repeats, 0) # type: ignore + return self.post_convert_data(img, orig_type) class RemoveRepeatedChannel(Transform): @@ -212,14 +219,16 @@ def __init__(self, repeats: int) -> None: self.repeats = repeats - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ if np.shape(img)[0] < 2: raise AssertionError("Image must have more than one channel") - return np.array(img[:: self.repeats, :]) + img, orig_type = self.pre_conv_data(img) + img = torch.Tensor(img[:: self.repeats, :]) # type: ignore + return self.post_convert_data(img, orig_type) class SplitChannel(Transform): @@ -238,7 +247,7 @@ class SplitChannel(Transform): def __init__(self, channel_dim: Optional[int] = None) -> None: self.channel_dim = channel_dim - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]: + def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: if self.channel_dim is None: # automatically select the default channel dim based on data type if isinstance(img, torch.Tensor): @@ -275,8 +284,8 @@ def __init__(self, dtype=np.float32) -> None: self.dtype = dtype def __call__( - self, img: Union[np.ndarray, torch.Tensor], dtype: Optional[Union[DtypeLike, torch.dtype]] = None - ) -> Union[np.ndarray, torch.Tensor]: + self, img: DataObjects.Images, dtype: Optional[Union[DtypeLike, torch.dtype]] = None + ) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor. @@ -317,10 +326,7 @@ def __call__(self, img) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() # type: ignore - elif has_cp and isinstance(img, cp_ndarray): - img = cp.asnumpy(img) # type: ignore + img, _ = convert_data_type(img, np.ndarray) return np.ascontiguousarray(img) @@ -388,7 +394,7 @@ def __init__(self, dim: Optional[int] = 0) -> None: raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.") self.dim = dim - def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: numpy arrays with required dimension `dim` removed @@ -452,14 +458,14 @@ def __init__( def __call__( self, - img: NdarrayTensor, + img: DataObjects.Images, prefix: Optional[str] = None, data_type: Optional[bool] = None, data_shape: Optional[bool] = None, value_range: Optional[bool] = None, data_value: Optional[bool] = None, additional_info: Optional[Callable] = None, - ) -> NdarrayTensor: + ) -> DataObjects.Images: """ Apply the transform to `img`, optionally take arguments similar to the class constructor. """ @@ -508,7 +514,7 @@ def __init__(self, delay_time: float = 0.0) -> None: super().__init__() self.delay_time: float = delay_time - def __call__(self, img: NdarrayTensor, delay_time: Optional[float] = None) -> NdarrayTensor: + def __call__(self, img: DataObjects.Images, delay_time: Optional[float] = None) -> DataObjects.Images: """ Args: img: data remain unchanged throughout this transform. 
@@ -546,7 +552,7 @@ def __init__(self, func: Optional[Callable] = None) -> None: raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func - def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None): + def __call__(self, img: DataObjects.Images, func: Optional[Callable] = None): """ Apply `self.func` to `img`. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index dd4fd381d9..9cf62a1227 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utility.array import ( @@ -53,7 +53,7 @@ ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.utils import ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys +from monai.utils.enums import DataObjects, InverseKeys __all__ = [ "AddChannelD", @@ -165,9 +165,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.identity = Identity() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): d[key] = self.identity(d[key]) @@ -190,7 +188,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke super().__init__(keys, allow_missing_keys) self.converter = AsChannelFirst(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -213,7 +211,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key super().__init__(keys, allow_missing_keys) self.converter = AsChannelLast(channel_dim=channel_dim) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -235,7 +233,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.adder = AddChannel() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.adder(d[key]) @@ -272,7 +270,7 @@ def __init__( self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix): d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"]) @@ -295,7 +293,7 @@ def __init__(self, keys: KeysCollection, repeats: int, 
allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RepeatChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -318,7 +316,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.repeater = RemoveRepeatedChannel(repeats) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.repeater(d[key]) @@ -358,9 +356,7 @@ def __init__( self.output_postfixes = output_postfixes self.splitter = SplitChannel(channel_dim=channel_dim) - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key in self.key_iterator(d): rets = self.splitter(d[key]) @@ -400,9 +396,7 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.converter = CastToType() - def __call__( - self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]] - ) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]: + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) for key, dtype in self.key_iterator(d, self.dtype): d[key] = self.converter(d[key], dtype=dtype) @@ -425,14 +419,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToTensor() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.converter(d[key]) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): # Create inverse transform @@ -459,7 +453,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToNumpy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -481,7 +475,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToCupy() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -503,7 +497,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No super().__init__(keys, allow_missing_keys) self.converter = ToPIL() - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -521,7 +515,7 @@ def __init__( super().__init__(keys, 
allow_missing_keys) self.transform = Transpose(indices) - def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key]) @@ -530,7 +524,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.push_transform(d, key, extra_info={"indices": indices}) return d - def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -582,7 +576,7 @@ def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool super().__init__(keys, allow_missing_keys) self.converter = SqueezeDim(dim=dim) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -640,7 +634,7 @@ def __init__( self.logger_handler = logger_handler self.printer = DataStats(logger_handler=logger_handler) - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info @@ -678,7 +672,7 @@ def __init__( self.delay_time = ensure_tuple_rep(delay_time, len(self.keys)) self.delayer = SimulateDelay() - def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key, delay_time in self.key_iterator(d, self.delay_time): d[key] = self.delayer(d[key], delay_time=delay_time) @@ -879,7 +873,7 @@ def __init__( # pytype: disable=annotation-type-mismatch super().__init__(keys, allow_missing_keys) self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -923,7 +917,7 @@ def __init__( self.image_key = image_key self.converter = FgBgToIndices(image_threshold, output_shape) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) image = d[self.image_key] if self.image_key else None for key in self.key_iterator(d): @@ -947,7 +941,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1126,7 +1120,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) 
-> Dict[Hashable, np.ndarray]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = dict(data)
         for key in self.key_iterator(d):
             d[key] = self.mapper(d[key])
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index 911270b3fd..e45bd25bcd 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -34,6 +34,7 @@
     min_version,
     optional_import,
 )
+from monai.utils.enums import DataObjects
 
 measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
 
@@ -92,7 +93,7 @@ def in_bounds(x: float, y: float, margin: float, maxx: float, maxy: float) -> bo
     return bool(margin <= x < (maxx - margin) and margin <= y < (maxy - margin))
 
 
-def is_empty(img: Union[np.ndarray, torch.Tensor]) -> bool:
+def is_empty(img: DataObjects.Images) -> bool:
     """
     Returns True if `img` is empty, that is, its maximum value is not greater than its minimum.
     """
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 00ef6eacef..791e00f45b 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -18,6 +18,7 @@
     BlendMode,
     ChannelMatching,
     CommonKeys,
+    DataObjects,
     ForwardMode,
     GridSampleMode,
     GridSamplePadMode,
@@ -30,7 +31,6 @@
     NumpyPadMode,
     PytorchPadMode,
     SkipMode,
-    TransformTypes,
     UpsampleMode,
     Weight,
 )
@@ -38,7 +38,6 @@
 from .misc import (
     MAX_SEED,
     ImageMetaKey,
-    convert_data_type,
     copy_to_device,
     dtype_numpy_to_torch,
     dtype_torch_to_numpy,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index ad4c8d3e1c..f1d90b6063 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, Hashable, Union
+from typing import Any, Dict, Hashable, Mapping, Union
 
 import numpy as np
 import torch
@@ -35,7 +35,7 @@
     "InverseKeys",
     "CommonKeys",
     "ForwardMode",
-    "TransformTypes",
+    "DataObjects",
 ]
 
@@ -267,8 +267,9 @@ class CommonKeys:
     LOSS = "loss"
 
 
-class TransformTypes:
-    """Common transform types."""
+class DataObjects:
+    """Common type aliases for array/tensor data and for the dicts/mappings that contain such data."""
 
     Images = Union[torch.Tensor, np.ndarray]
-    ImageDict = Dict[Hashable, Images]
+    Dict = Dict[Hashable, Any]
+    Mapping = Mapping[Hashable, Any]
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index 8007728629..bd8e46d8b5 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -22,8 +22,6 @@
 import numpy as np
 import torch
 
-from monai.utils.enums import TransformTypes
-
 __all__ = [
     "zip_with",
     "star_zip_with",
@@ -44,7 +42,6 @@
     "MAX_SEED",
     "copy_to_device",
     "ImageMetaKey",
-    "convert_data_type",
 ]
 
 _seed = None
@@ -362,15 +359,3 @@ class ImageMetaKey:
 
     FILENAME_OR_OBJ = "filename_or_obj"
     PATCH_INDEX = "patch_index"
-
-
-def convert_data_type(data, output_type: type) -> Tuple[TransformTypes.Images, type]:
-    """Convert to torch.Tensor/np.ndarray."""
-    orig_type = type(data)
-    assert orig_type in (torch.Tensor, np.ndarray)
-    if orig_type is np.ndarray and output_type is torch.Tensor:
-        data = torch.Tensor(data)
-    elif orig_type is torch.Tensor and output_type is np.ndarray:
-        data = data.detach().cpu().numpy()
-
-    return data, orig_type
diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py
index bb3e37775e..dcb23e2092 100644
--- a/monai/visualize/img2tensorboard.py
+++ b/monai/visualize/img2tensorboard.py
@@ -9,13 +9,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Optional, Sequence import numpy as np import torch from monai.transforms import rescale_array from monai.utils import optional_import +from monai.utils.enums import DataObjects PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") @@ -31,7 +32,7 @@ __all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] -def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: +def _image3_animated_gif(tag: str, image: DataObjects.Images, scale_factor: float = 1.0) -> Summary: """Function to actually create the animated gif. Args: @@ -60,7 +61,7 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale def make_animated_gif_summary( tag: str, - image: Union[np.ndarray, torch.Tensor], + image: DataObjects.Images, max_out: int = 3, animation_axes: Sequence[int] = (3,), image_axes: Sequence[int] = (1, 2), @@ -95,7 +96,7 @@ def make_animated_gif_summary( image = image[tuple(slicing)] for it_i in range(min(max_out, list(image.shape)[0])): - one_channel_img: Union[torch.Tensor, np.ndarray] = ( + one_channel_img: DataObjects.Images = ( image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :] ) summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor) @@ -105,7 +106,7 @@ def make_animated_gif_summary( def add_animated_gif( writer: SummaryWriter, tag: str, - image_tensor: Union[np.ndarray, torch.Tensor], + image_tensor: DataObjects.Images, max_out: int, scale_factor: float, global_step: Optional[int] = None, @@ -132,7 +133,7 @@ def add_animated_gif( def add_animated_gif_no_channels( writer: SummaryWriter, tag: str, - image_tensor: Union[np.ndarray, torch.Tensor], + image_tensor: DataObjects.Images, max_out: int, scale_factor: float, global_step: Optional[int] = None, @@ -159,7 +160,7 @@ def add_animated_gif_no_channels( def plot_2d_or_3d_image( - data: Union[torch.Tensor, np.ndarray], + data: DataObjects.Images, step: int, writer: SummaryWriter, index: int = 0, diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 48aac7ec56..ac9b62f27b 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -18,7 +18,7 @@ from nibabel.processing import resample_to_output from parameterized import parameterized -from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd, convert_data_type FILES = tuple( os.path.join(os.path.dirname(__file__), "testing_data", filename) @@ -36,7 +36,8 @@ def test_load_spacingd(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + im_np, _ = convert_data_type(data_dict["image"][0], np.ndarray) + anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") @@ -58,7 +59,8 @@ def test_load_spacingd_rotate(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=True, padding_mode="zeros")(data_dict) t1 = 
time.time() print(f"time monai: {t1 - t}") - anat = nibabel.Nifti1Image(data_dict["image"][0], data_dict["image_meta_dict"]["original_affine"]) + im_np, _ = convert_data_type(data_dict["image"][0], np.ndarray) + anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index 070e0e2b8d..6224d8e88d 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -11,12 +11,12 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import RemoveRepeatedChannel -TEST_CASE_1 = [{"repeats": 2}, np.array([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] +TEST_CASE_1 = [{"repeats": 2}, torch.Tensor([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] class TestRemoveRepeatedChannel(unittest.TestCase): From 5680d9e6340ef2dc2bbcc9231fb009edf2113226 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 28 May 2021 17:10:14 +0100 Subject: [PATCH 009/176] to_numpy, to_tensor, transpose Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- monai/transforms/utility/array.py | 37 ++++++++++------- tests/test_to_numpy.py | 2 +- tests/test_to_tensor.py | 67 +++++++++++++++++++++++++++++++ tests/test_transpose.py | 8 ++-- 5 files changed, 96 insertions(+), 20 deletions(-) create mode 100644 tests/test_to_tensor.py diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index c3c42fff13..5bb40eba6c 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -41,7 +41,7 @@ def convert_data_type(data, output_type: type) -> Tuple[DataObjects.Images, type orig_type = type(data) assert orig_type in (torch.Tensor, np.ndarray), f"Expected types {(torch.Tensor, np.ndarray)}, got {orig_type}" if orig_type is np.ndarray and output_type is torch.Tensor: - data = torch.Tensor(data) + data = torch.Tensor(np.ascontiguousarray(data)) elif orig_type is torch.Tensor and output_type is np.ndarray: data = data.detach().cpu().numpy() diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 83cbab7a9c..3fe7a751c1 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,7 @@ import torch from monai.config import DtypeLike -from monai.transforms.transform import Randomizable, Transform, convert_data_type +from monai.transforms.transform import NumpyTransform, Randomizable, Transform, convert_data_type from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects @@ -308,13 +308,16 @@ class ToTensor(Transform): Converts the input image to a tensor without applying any other transformations. """ - def __call__(self, img) -> torch.Tensor: + def __call__(self, img: Union[DataObjects.Images, Sequence]) -> torch.Tensor: """ Apply the transform to `img` and make it contiguous. 
""" - if isinstance(img, torch.Tensor): - return img.contiguous() - return torch.as_tensor(np.ascontiguousarray(img)) + if isinstance(img, Sequence): + img = torch.Tensor(img) + else: + img, _ = convert_data_type(img, torch.Tensor) + img = img.contiguous() + return img class ToNumpy(Transform): @@ -322,12 +325,16 @@ class ToNumpy(Transform): Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. """ - def __call__(self, img) -> np.ndarray: + def __call__(self, img: Union[DataObjects.Images, Sequence]) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - img, _ = convert_data_type(img, np.ndarray) - return np.ascontiguousarray(img) + if isinstance(img, Sequence): + img = np.array(img) + else: + img, _ = convert_data_type(img, np.ndarray) + img = np.ascontiguousarray(img) + return img class ToCupy(Transform): @@ -335,7 +342,7 @@ class ToCupy(Transform): Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor. """ - def __call__(self, img): + def __call__(self, img: Union[DataObjects.Images, Sequence]): """ Apply the transform to `img` and make it contiguous. """ @@ -355,12 +362,10 @@ def __call__(self, img): """ if isinstance(img, PILImageImage): return img - if isinstance(img, torch.Tensor): - img = img.detach().cpu().numpy() - return pil_image_fromarray(img) + return pil_image_fromarray(ToNumpy()(img)) -class Transpose(Transform): +class Transpose(NumpyTransform): """ Transposes the input image based on the given `indices` dimension ordering. """ @@ -368,11 +373,13 @@ class Transpose(Transform): def __init__(self, indices: Optional[Sequence[int]]) -> None: self.indices = None if indices is None else tuple(indices) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return img.transpose(self.indices) # type: ignore + img, orig_type = self.pre_conv_data(img) + img = img.transpose(self.indices) # type: ignore + return self.post_convert_data(img, orig_type) class SqueezeDim(Transform): diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index 6e112e6be8..0e0d6d1caf 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -23,7 +23,7 @@ class TestToNumpy(unittest.TestCase): @skipUnless(has_cp, "CuPy is required.") - def test_cumpy_input(self): + def test_cupy_input(self): test_data = cp.array([[1, 2], [3, 4]]) test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py new file mode 100644 index 0000000000..6a36786914 --- /dev/null +++ b/tests/test_to_tensor.py @@ -0,0 +1,67 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+import torch
+
+from monai.transforms import ToTensor
+from monai.utils import optional_import
+
+cp, has_cp = optional_import("cupy")
+
+
+class TestToTensor(unittest.TestCase):
+    @skipUnless(has_cp, "CuPy is required.")
+    def test_cupy_input(self):
+        test_data = cp.array([[1, 2], [3, 4]])
+        test_data = cp.rot90(test_data)
+        self.assertFalse(test_data.flags["C_CONTIGUOUS"])
+        result = ToTensor()(test_data)
+        self.assertTrue(isinstance(result, torch.Tensor))
+        self.assertTrue(result.is_contiguous())
+        torch.testing.assert_allclose(result, test_data.get())
+
+    def test_numpy_input(self):
+        test_data = np.array([[1, 2], [3, 4]])
+        test_data = np.rot90(test_data)
+        self.assertFalse(test_data.flags["C_CONTIGUOUS"])
+        result = ToTensor()(test_data)
+        self.assertTrue(isinstance(result, torch.Tensor))
+        self.assertTrue(result.is_contiguous())
+        torch.testing.assert_allclose(result, np.ascontiguousarray(test_data))
+
+    def test_tensor_input(self):
+        test_data = torch.tensor([[1, 2], [3, 4]])
+        test_data = test_data.rot90()
+        self.assertFalse(test_data.is_contiguous())
+        result = ToTensor()(test_data)
+        self.assertTrue(isinstance(result, torch.Tensor))
+        self.assertTrue(result.is_contiguous())
+        torch.testing.assert_allclose(result, test_data)
+
+    def test_list_tuple(self):
+        test_data = [[1, 2], [3, 4]]
+        result = ToTensor()(test_data)
+        torch.testing.assert_allclose(result, test_data)
+        test_data = ((1, 2), (3, 4))
+        result = ToTensor()(test_data)
+        torch.testing.assert_allclose(result, test_data)
+
+
+if __name__ == "__main__":
+    # use the standard unittest entry point so that skip decorators
+    # (e.g. the CuPy guard above) behave the same as under pytest
+    # and so that every test case in this file is collected and run,
+    # including test_cupy_input when CuPy is available.
+    unittest.main()
diff --git a/tests/test_transpose.py b/tests/test_transpose.py
index 3b758b5aa2..acc4ce3e9b 100644
--- a/tests/test_transpose.py
+++ b/tests/test_transpose.py
@@ -13,6 +13,7 @@
 import numpy as np
 from parameterized import parameterized
+import torch
 
 from monai.transforms import Transpose
 
@@ -31,9 +32,10 @@ class TestTranspose(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
     def test_transpose(self, im, indices):
         tr = Transpose(indices)
-        out1 = tr(im)
-        out2 = np.transpose(im, indices)
-        np.testing.assert_array_equal(out1, out2)
+        for array_class in (np.array, torch.Tensor):
+            out1 = tr(array_class(im))
+            out2 = np.transpose(im, indices)
+            np.testing.assert_array_equal(out1, out2)
 
 
 if __name__ == "__main__":

From d921834470b4638732328d45d1b8b63c67140062 Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Fri, 28 May 2021 17:34:31 +0100
Subject: [PATCH 010/176] add todotransform

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/transforms/spatial/array.py | 34 +++++++++++++++----------------
 monai/transforms/transform.py     | 18 ++++++++++++++++
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index e6dc90d742..430d68be60 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -23,7 +23,7 @@
 from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
 from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
 from monai.transforms.croppad.array import CenterSpatialCrop
-from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
+from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe,
ToDoTransform, TorchTransform, Transform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -76,7 +76,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class Spacing(Transform): +class Spacing(ToDoTransform): """ Resample input image into the specified `pixdim`. """ @@ -208,7 +208,7 @@ def __call__( return output_data, affine, new_affine -class Orientation(Transform): +class Orientation(ToDoTransform): """ Change the input image's orientation into the specified based on `axcodes`. """ @@ -293,7 +293,7 @@ def __call__( return data_array, affine, new_affine -class Flip(Transform): +class Flip(TorchTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. Uses ``np.flip`` in practice. See numpy.flip for additional details: @@ -323,7 +323,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(result, orig_type) -class Resize(Transform): +class Resize(ToDoTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -392,7 +392,7 @@ def __call__( return np.asarray(resized) -class Rotate(Transform, ThreadUnsafe): +class Rotate(ToDoTransform, ThreadUnsafe): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -502,7 +502,7 @@ def get_rotation_matrix(self) -> Optional[np.ndarray]: return self._rotation_matrix -class Zoom(Transform): +class Zoom(ToDoTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/nn.functional.html#interpolate. @@ -590,7 +590,7 @@ def __call__( return zoomed[tuple(slice_vec)] -class Rotate90(Transform): +class Rotate90(ToDoTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See np.rot90 for additional details: @@ -622,7 +622,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return result.astype(img.dtype) -class RandRotate90(RandomizableTransform): +class RandRotate90(ToDoTransform, RandomizableTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -659,7 +659,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return rotator(img) -class RandRotate(RandomizableTransform): +class RandRotate(ToDoTransform, RandomizableTransform): """ Randomly rotate the input arrays. @@ -763,7 +763,7 @@ def __call__( return np.array(rotator(img)) -class RandFlip(RandomizableTransform): +class RandFlip(TorchTransform, RandomizableTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -789,7 +789,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.flipper(img) -class RandAxisFlip(RandomizableTransform): +class RandAxisFlip(TorchTransform, RandomizableTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. @@ -820,7 +820,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return flipper(img) -class RandZoom(RandomizableTransform): +class RandZoom(ToDoTransform, RandomizableTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -919,7 +919,7 @@ def __call__( ) -class AffineGrid(Transform): +class AffineGrid(ToDoTransform): """ Affine transforms on the coordinates. 
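
The retagging above makes Flip and the random flips TorchTransform subclasses:
torch is the compute backend, and numpy inputs are converted on the way in and
back on the way out. A minimal self-contained sketch of that round trip, with
assumed names (flip_spatial is illustrative only, and MONAI's own
map_spatial_axes additionally normalizes negative axes, which is omitted here):

    import numpy as np
    import torch

    def flip_spatial(img, spatial_axis=None):
        """Flip a channel-first image; return the same array type as the input."""
        input_is_numpy = isinstance(img, np.ndarray)
        # numpy -> torch on the way in (mirrors pre_conv_data)
        img_t = torch.as_tensor(np.ascontiguousarray(img)) if input_is_numpy else img
        # offset axes by one so the channel dimension is never flipped
        if spatial_axis is None:
            dims = list(range(1, img_t.ndim))
        else:
            dims = [int(a) + 1 for a in np.atleast_1d(spatial_axis)]
        out = torch.flip(img_t, dims)
        # torch -> numpy on the way out (mirrors post_convert_data)
        return out.numpy() if input_is_numpy else out

    flip_spatial(np.arange(8).reshape(1, 2, 4), spatial_axis=0)     # ndarray in, ndarray out
    flip_spatial(torch.arange(8).reshape(1, 2, 4), spatial_axis=0)  # Tensor in, Tensor out
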
@@ -1015,7 +1015,7 @@ def __call__(
         return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine
 
 
-class RandAffineGrid(Randomizable, Transform):
+class RandAffineGrid(Randomizable, ToDoTransform):
     """
     Generate randomised affine grid.
     """
@@ -1112,7 +1112,7 @@ def get_transformation_matrix(self) -> Optional[DataObjects.Images]:
         return self.affine
 
 
-class RandDeformGrid(Randomizable, Transform):
+class RandDeformGrid(Randomizable, ToDoTransform):
     """
     Generate random deformation grid.
     """
@@ -1162,7 +1162,7 @@ def __call__(self, spatial_size: Sequence[int]):
         return control_grid
 
 
-class Resample(Transform):
+class Resample(ToDoTransform):
     def __init__(
         self,
         mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 5bb40eba6c..9475ac7159 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -236,6 +236,19 @@ def post_convert_data(data: DataObjects.Images, output_type: type) -> DataObject
         return data
 
 
+class TorchOrNumpyTransform(Transform):
+    """Transforms that inherit from this class process the input the same regardless of whether
+    the input is torch or numpy. No conversions are needed."""
+    pass
+
+
+class TorchTransform(Transform):
+    """Transforms that inherit from this class use torch under the hood: if the input is not
+    torch, it is converted to torch for the computation and converted back at the end."""
+
+    array_class = torch.Tensor
+
+
 class NumpyTransform(Transform):
     """Most transforms use torch. Transforms that inherit from this class, however, use numpy under the hood.
     This means that if the input image is `torch.Tensor`, it will be converted to numpy and then reverted
@@ -244,6 +257,11 @@ class NumpyTransform(Transform):
     array_class = np.ndarray
 
 
+class ToDoTransform(Transform):
+    """Transforms that inherit from this class are still to be updated.
This is a temporary class.""" + pass + + class RandomizableTransform(Randomizable, Transform): """ An interface for handling random state locally, currently based on a class variable `R`, From 59a7ce7866f9596479f3ecdbb431f146cb3a31d2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 30 May 2021 16:07:29 +0100 Subject: [PATCH 011/176] full ci/cd for feature branches Signed-off-by: Wenqi Li --- .github/workflows/cleanup.yml | 1 + .github/workflows/setupapp.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/cleanup.yml b/.github/workflows/cleanup.yml index f3d297286e..80d72d8066 100644 --- a/.github/workflows/cleanup.yml +++ b/.github/workflows/cleanup.yml @@ -4,6 +4,7 @@ on: workflow_run: workflows: - "build" + - "deploy" types: ["requested"] jobs: diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 5c011a3af1..cb9bdd6512 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -7,6 +7,7 @@ on: - dev - main - releasing/* + - feature/* jobs: # caching of these jobs: From 4662522ed580c0f3a1a930d97e23605894665b88 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 1 Jun 2021 14:01:36 +0100 Subject: [PATCH 012/176] more transforms Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 37 +++++++++++-------- monai/transforms/post/array.py | 28 ++++++++------ monai/transforms/post/dictionary.py | 4 +- monai/transforms/spatial/array.py | 11 +++++- monai/transforms/transform.py | 2 + monai/transforms/utility/array.py | 57 +++++++++++++++++++---------- tests/test_squeezedim.py | 26 ++++++------- tests/test_transpose.py | 2 +- 8 files changed, 101 insertions(+), 66 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6688846637..716c5f92e2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, Transform +from monai.transforms.transform import RandomizableTransform, TorchOrNumpyTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -171,7 +171,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class ShiftIntensity(Transform): +class ShiftIntensity(TorchOrNumpyTransform): """ Shift intensity uniformly for the entire image with specified `offset`. @@ -182,14 +182,17 @@ class ShiftIntensity(Transform): def __init__(self, offset: float) -> None: self.offset = offset - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.asarray((img + self.offset), dtype=img.dtype) + out = img + self.offset + if isinstance(out, torch.Tensor): + return out.type(img.dtype) # type: ignore + return out.astype(img.dtype) # type: ignore -class RandShiftIntensity(RandomizableTransform): +class RandShiftIntensity(RandomizableTransform, TorchOrNumpyTransform): """ Randomly shift intensity with randomly picked offset. 
""" @@ -214,7 +217,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -585,7 +588,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: ) -class ScaleIntensityRange(Transform): +class ScaleIntensityRange(TorchOrNumpyTransform): """ Apply specific intensity scaling to the whole numpy array. Scaling from [a_min, a_max] to [b_min, b_max] with clip option. @@ -605,7 +608,7 @@ def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: self.b_max = b_max self.clip = clip - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ @@ -616,11 +619,12 @@ def __call__(self, img: np.ndarray): img = (img - self.a_min) / (self.a_max - self.a_min) img = img * (self.b_max - self.b_min) + self.b_min if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + mod = torch if isinstance(img, torch.Tensor) else np + img = mod.clip(img, self.b_min, self.b_max) # type: ignore return img -class AdjustContrast(Transform): +class AdjustContrast(TorchOrNumpyTransform): """ Changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -635,17 +639,17 @@ def __init__(self, gamma: float) -> None: raise AssertionError("gamma must be a float or int number.") self.gamma = gamma - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ epsilon = 1e-7 img_min = img.min() img_range = img.max() - img_min - return np.power(((img - img_min) / float(img_range + epsilon)), self.gamma) * img_range + img_min + return ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min -class RandAdjustContrast(RandomizableTransform): +class RandAdjustContrast(TorchOrNumpyTransform, RandomizableTransform): """ Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -690,7 +694,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return adjuster(img) -class ScaleIntensityRangePercentiles(Transform): +class ScaleIntensityRangePercentiles(TorchOrNumpyTransform): """ Apply range scaling to a numpy array based on the intensity distribution of the input. @@ -759,7 +763,7 @@ def __init__( self.clip = clip self.relative = relative - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
""" @@ -776,7 +780,8 @@ def __call__(self, img: np.ndarray): img = scalar(img) if self.clip: - img = np.asarray(np.clip(img, self.b_min, self.b_max)) + mod = torch if isinstance(img, torch.Tensor) else np + img = mod.clip(img, self.b_min, self.b_max) # type: ignore return img diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 0a7cb8c716..00f2f370a8 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -22,7 +22,7 @@ from monai.networks import one_hot from monai.networks.layers import GaussianFilter -from monai.transforms.transform import Transform +from monai.transforms.transform import TorchTransform, Transform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple from monai.utils.enums import DataObjects @@ -38,7 +38,7 @@ ] -class Activations(Transform): +class Activations(TorchTransform): """ Add activation operations to the model output, typically `Sigmoid` or `Softmax`. @@ -64,11 +64,11 @@ def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional def __call__( self, - img: torch.Tensor, + img: DataObjects.Images, sigmoid: Optional[bool] = None, softmax: Optional[bool] = None, other: Optional[Callable] = None, - ) -> torch.Tensor: + ) -> DataObjects.Images: """ Args: sigmoid: whether to execute sigmoid function on model output before transform. @@ -89,8 +89,10 @@ def __call__( if other is not None and not callable(other): raise TypeError(f"other must be None or callable but is {type(other).__name__}.") + img, orig_type = self.pre_conv_data(img) + # convert to float as activation must operate on float tensor - img = img.float() + img = img.float() # type: ignore if sigmoid or self.sigmoid: img = torch.sigmoid(img) if softmax or self.softmax: @@ -103,7 +105,7 @@ def __call__( if act_func is not None: img = act_func(img) - return img + return self.post_convert_data(img, orig_type) class AsDiscrete(Transform): @@ -181,7 +183,7 @@ def __call__( return img.float() -class KeepLargestConnectedComponent(Transform): +class KeepLargestConnectedComponent(TorchTransform): """ Keeps only the largest connected component in the image. This transform can be used as a post-processing step to clean up over-segment areas in model output. @@ -247,7 +249,7 @@ def __init__( self.independent = independent self.connectivity = connectivity - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]). @@ -255,10 +257,12 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: Returns: A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]). """ + img, orig_type = self.pre_conv_data(img) + channel_dim = 1 if img.shape[channel_dim] == 1: - img = torch.squeeze(img, dim=channel_dim) + img = torch.squeeze(img, dim=channel_dim) # type: ignore if self.independent: for i in self.applied_labels: @@ -286,10 +290,10 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim) background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim) applied_img[background_mask] = 0 - img[:, self.applied_labels, ...] = applied_img.type(img.type()) - output = img + img[:, self.applied_labels, ...] 
= applied_img.type(img.type()) # type: ignore + output = img # type: ignore - return output + return self.post_convert_data(output, orig_type) class LabelToContour(Transform): diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 883ac8c02f..a2890f71ee 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -112,7 +112,7 @@ def __init__( self.other = ensure_tuple_rep(other, len(self.keys)) self.converter = Activations() - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: d = dict(data) for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other): d[key] = self.converter(d[key], sigmoid, softmax, other) @@ -207,7 +207,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Mapping: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1ed06ec0e7..c6e7f12332 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -23,7 +23,14 @@ from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop -from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, ToDoTransform, TorchTransform, Transform +from monai.transforms.transform import ( + Randomizable, + RandomizableTransform, + ThreadUnsafe, + ToDoTransform, + TorchTransform, + Transform, +) from monai.transforms.utils import ( create_control_grid, create_grid, @@ -1750,7 +1757,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class AddCoordinateChannels(Transform): +class AddCoordinateChannels(ToDoTransform): """ Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, to allow feeding of the patch's location into the network. diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 9475ac7159..9d0e81331f 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -239,6 +239,7 @@ def post_convert_data(data: DataObjects.Images, output_type: type) -> DataObject class TorchOrNumpyTransform(Transform): """Transforms that inherit from this class process the input the same regardless of whether the input is torch or numpy. No conversions are needed.""" + pass @@ -259,6 +260,7 @@ class NumpyTransform(Transform): class ToDoTransform(Transform): """Transforms that inherit from this class are still to be updated. 
This is a temporary class.""" + pass diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3fe7a751c1..cad692fcf0 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,7 +22,14 @@ import torch from monai.config import DtypeLike -from monai.transforms.transform import NumpyTransform, Randomizable, Transform, convert_data_type +from monai.transforms.transform import ( + NumpyTransform, + Randomizable, + TorchOrNumpyTransform, + TorchTransform, + Transform, + convert_data_type, +) from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects @@ -59,7 +66,7 @@ ] -class Identity(Transform): +class Identity(TorchOrNumpyTransform): """ Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data. As the output value is same as input, it can be used as a testing tool to verify the transform chain, @@ -67,14 +74,14 @@ class Identity(Transform): """ - def __call__(self, img: DataObjects.Images) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - return np.asanyarray(img) + return img -class AsChannelFirst(Transform): +class AsChannelFirst(TorchTransform): """ Change the channel dimension of the image to the first dimension. @@ -104,7 +111,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(img, orig_type) -class AsChannelLast(Transform): +class AsChannelLast(TorchTransform): """ Change the channel dimension of the image to the last dimension. @@ -133,7 +140,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(img, orig_type) -class AddChannel(Transform): +class AddChannel(TorchOrNumpyTransform): """ Adds a 1-length channel dimension to the input image. @@ -154,7 +161,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img[None] -class EnsureChannelFirst(Transform): +class EnsureChannelFirst(TorchTransform): """ Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape. It extracts the `original_channel_dim` info from provided meta_data dictionary. @@ -231,7 +238,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(img, orig_type) -class SplitChannel(Transform): +class SplitChannel(TorchOrNumpyTransform): """ Split Numpy array or PyTorch Tensor data according to the channel dim. It can help applying different following transforms to different channels. @@ -303,7 +310,7 @@ def __call__( raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") -class ToTensor(Transform): +class ToTensor(TorchTransform): """ Converts the input image to a tensor without applying any other transformations. """ @@ -316,11 +323,11 @@ def __call__(self, img: Union[DataObjects.Images, Sequence]) -> torch.Tensor: img = torch.Tensor(img) else: img, _ = convert_data_type(img, torch.Tensor) - img = img.contiguous() + img = img.contiguous() # type: ignore return img -class ToNumpy(Transform): +class ToNumpy(NumpyTransform): """ Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. 
""" @@ -382,7 +389,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(img, orig_type) -class SqueezeDim(Transform): +class SqueezeDim(TorchOrNumpyTransform): """ Squeeze a unitary dimension. """ @@ -406,7 +413,12 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: Args: img: numpy arrays with required dimension `dim` removed """ - return img.squeeze(self.dim) # type: ignore + if self.dim is None: + return img.squeeze() + # for pytorch/numpy unification + if img.shape[self.dim] != 1: + raise ValueError("Can only squeeze singleton dimension") + return img.squeeze(self.dim) class DataStats(Transform): @@ -500,7 +512,7 @@ def __call__( return img -class SimulateDelay(Transform): +class SimulateDelay(TorchOrNumpyTransform): """ This is a pass through transform to be used for testing purposes. It allows adding fake behaviors that are useful for testing purposes to simulate @@ -580,7 +592,7 @@ def __call__(self, img: DataObjects.Images, func: Optional[Callable] = None): raise ValueError("Incompatible values: func=None and self.func=None.") -class LabelToMask(Transform): +class LabelToMask(NumpyTransform): """ Convert labels to mask for other tasks. A typical usage is to convert segmentation labels to mask data to pre-process images and then feed the images into classification network. @@ -608,8 +620,11 @@ def __init__( # pytype: disable=annotation-type-mismatch self.merge_channels = merge_channels def __call__( - self, img: np.ndarray, select_labels: Optional[Union[Sequence[int], int]] = None, merge_channels: bool = False - ): + self, + img: DataObjects.Images, + select_labels: Optional[Union[Sequence[int], int]] = None, + merge_channels: bool = False, + ) -> DataObjects.Images: """ Args: select_labels: labels to generate mask from. for 1 channel label, the `select_labels` @@ -618,6 +633,8 @@ def __call__( merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. 
""" + img, orig_type = self.pre_conv_data(img) + if select_labels is None: select_labels = self.select_labels else: @@ -628,7 +645,9 @@ def __call__( else: data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) - return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + out = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + + return self.post_convert_data(out, orig_type) # type: ignore class FgBgToIndices(Transform): diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index 01ea489320..d9ab217b9e 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -17,29 +17,27 @@ from monai.transforms import SqueezeDim -TEST_CASE_1 = [{"dim": None}, np.random.rand(1, 2, 1, 3), (2, 3)] +TESTS, TESTS_FAIL = [], [] +for r in (np.random.rand, torch.rand): + TESTS.append([{"dim": None}, r(1, 2, 1, 3), (2, 3)]) + TESTS.append([{"dim": 2}, r(1, 2, 1, 8, 16), (1, 2, 8, 16)]) + TESTS.append([{"dim": -1}, r(1, 1, 16, 8, 1), (1, 1, 16, 8)]) + TESTS.append([{}, r(1, 2, 1, 3), (2, 1, 3)]) -TEST_CASE_2 = [{"dim": 2}, np.random.rand(1, 2, 1, 8, 16), (1, 2, 8, 16)] - -TEST_CASE_3 = [{"dim": -1}, np.random.rand(1, 1, 16, 8, 1), (1, 1, 16, 8)] - -TEST_CASE_4 = [{}, np.random.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_4_PT = [{}, torch.rand(1, 2, 1, 3), (2, 1, 3)] - -TEST_CASE_5 = [ValueError, {"dim": -2}, np.random.rand(1, 1, 16, 8, 1)] - -TEST_CASE_6 = [TypeError, {"dim": 0.5}, np.random.rand(1, 1, 16, 8, 1)] + TESTS_FAIL.append([ValueError, {"dim": -2}, r(1, 1, 16, 8, 1)]) + TESTS_FAIL.append([TypeError, {"dim": 0.5}, r(1, 1, 16, 8, 1)]) class TestSqueezeDim(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_4_PT]) + @parameterized.expand(TESTS) def test_shape(self, input_param, test_data, expected_shape): + result = SqueezeDim(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): + with self.assertRaises(exception): SqueezeDim(**input_param)(test_data) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index acc4ce3e9b..fd505d8ba9 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -12,8 +12,8 @@ import unittest import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import Transpose From 9a6901c9554de67ca9648043818994f7bd30205b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 1 Jun 2021 17:55:51 +0100 Subject: [PATCH 013/176] SpatialCrop, CenterSpatialCrop, AsDiscrete, RepeatChannel, RemoveRepeatedChannel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 20 ++++++++++---------- monai/transforms/post/array.py | 17 ++++++++++------- monai/transforms/post/dictionary.py | 2 +- monai/transforms/utility/array.py | 23 ++++++++++------------- tests/test_remove_repeated_channel.py | 7 +++++-- tests/test_repeat_channel.py | 7 +++++-- tests/test_transpose.py | 21 ++++++++------------- 7 files changed, 49 insertions(+), 48 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 91443eaa91..fbb4741535 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -21,7 +21,7 @@ 
from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import Randomizable, TorchOrNumpyTransform, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, generate_pos_neg_label_crop_centers, @@ -211,7 +211,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N return spatial_pad(img) -class SpatialCrop(Transform): +class SpatialCrop(TorchOrNumpyTransform): """ General purpose cropper to produce sub-volume region of interest (ROI). It can support to crop ND spatial (channel-first) data. @@ -259,7 +259,7 @@ def __init__( # convert to slices self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] - def __call__(self, img: DataObjects.Images): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -269,7 +269,7 @@ def __call__(self, img: DataObjects.Images): return img[tuple(slices)] -class CenterSpatialCrop(Transform): +class CenterSpatialCrop(TorchOrNumpyTransform): """ Crop at the center of image with specified ROI size. @@ -281,7 +281,7 @@ class CenterSpatialCrop(Transform): def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -556,7 +556,7 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): Crop and pad based on the bounding box. """ - cropped = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) + cropped: np.ndarray = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) # type: ignore pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) @@ -620,10 +620,10 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results = [] + results: List[np.ndarray] = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results.append(cropper(img)) + results.append(cropper(img)) # type: ignore return results @@ -752,7 +752,7 @@ def __call__( if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) + results.append(cropper(img)) # type: ignore return results @@ -793,7 +793,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. 
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
        """
-        return self.padder(self.cropper(img), mode=mode)
+        return self.padder(self.cropper(img), mode=mode)  # type: ignore
 
 
 class BoundingRect(Transform):
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 00f2f370a8..05d9f3717a 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -108,7 +108,7 @@ def __call__(
         return self.post_convert_data(img, orig_type)
 
 
-class AsDiscrete(Transform):
+class AsDiscrete(TorchTransform):
     """
     Execute after model forward to transform model output to discrete values.
     It can complete below operations:
@@ -147,13 +147,13 @@ def __init__(
 
     def __call__(
         self,
-        img: torch.Tensor,
+        img: DataObjects.Images,
         argmax: Optional[bool] = None,
         to_onehot: Optional[bool] = None,
         n_classes: Optional[int] = None,
        threshold_values: Optional[bool] = None,
        logit_thresh: Optional[float] = None,
-    ) -> torch.Tensor:
+    ) -> DataObjects.Images:
         """
         Args:
             argmax: whether to execute argmax function on input data before transform.
@@ -168,19 +168,22 @@ def __call__(
                 Defaults to ``self.logit_thresh``.
 
         """
+        img_t: torch.Tensor
+        img_t, orig_type = self.pre_conv_data(img)  # type: ignore
+
         if argmax or self.argmax:
-            img = torch.argmax(img, dim=1, keepdim=True)
+            img_t = torch.argmax(img_t, dim=1, keepdim=True)
 
         if to_onehot or self.to_onehot:
             _nclasses = self.n_classes if n_classes is None else n_classes
             if not isinstance(_nclasses, int):
                 raise AssertionError("One of self.n_classes or n_classes must be an integer")
-            img = one_hot(img, _nclasses)
+            img_t = one_hot(img_t, _nclasses)
 
         if threshold_values or self.threshold_values:
-            img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)
+            img_t = img_t >= (logit_thresh if logit_thresh is not None else self.logit_thresh)
 
-        return img.float()
+        return self.post_convert_data(img_t.float(), orig_type)
 
 
 class KeepLargestConnectedComponent(TorchTransform):
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index a2890f71ee..3e0308d156 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -159,7 +159,7 @@ def __init__(
         self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
         self.converter = AsDiscrete()
 
-    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+    def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
         d = dict(data)
         for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator(
             d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index cad692fcf0..b97e93c8f9 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -186,7 +186,7 @@ def __call__(self, img: DataObjects.Images, meta_dict: Optional[Dict] = None) ->
         return AsChannelFirst(channel_dim=channel_dim)(img)
 
 
-class RepeatChannel(Transform):
+class RepeatChannel(TorchOrNumpyTransform):
     """
     Repeat channel data to construct expected input shape for models.
     The `repeats` count includes the origin data, for example:
@@ -205,12 +205,11 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Apply the transform to `img`, assuming `img` is a "channel-first" array.
""" - img, orig_type = self.pre_conv_data(img) - img = torch.repeat_interleave(img, self.repeats, 0) # type: ignore - return self.post_convert_data(img, orig_type) + repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat + return repeeat_fn(img, self.repeats, 0) # type: ignore -class RemoveRepeatedChannel(Transform): +class RemoveRepeatedChannel(TorchOrNumpyTransform): """ RemoveRepeatedChannel data to undo RepeatChannel The `repeats` count specifies the deletion of the origin data, for example: @@ -230,12 +229,10 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is a "channel-first" array. """ - if np.shape(img)[0] < 2: + if img.shape[0] < 2: raise AssertionError("Image must have more than one channel") - img, orig_type = self.pre_conv_data(img) - img = torch.Tensor(img[:: self.repeats, :]) # type: ignore - return self.post_convert_data(img, orig_type) + return img[:: self.repeats, :] class SplitChannel(TorchOrNumpyTransform): @@ -277,7 +274,7 @@ def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: return outputs -class CastToType(Transform): +class CastToType(TorchOrNumpyTransform): """ Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to specified PyTorch data type. @@ -304,9 +301,9 @@ def __call__( """ if isinstance(img, np.ndarray): - return img.astype(self.dtype if dtype is None else dtype) # type: ignore + return img.astype(dtype or self.dtype) # type: ignore if isinstance(img, torch.Tensor): - return torch.as_tensor(img, dtype=self.dtype if dtype is None else dtype) + return img.to(dtype=dtype or self.dtype) # type: ignore raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") @@ -421,7 +418,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img.squeeze(self.dim) -class DataStats(Transform): +class DataStats(TorchOrNumpyTransform): """ Utility transform to show the statistics of data for debug or analysis. It can be inserted into any place of a transform chain and check results of previous transforms. 
diff --git a/tests/test_remove_repeated_channel.py b/tests/test_remove_repeated_channel.py index 6224d8e88d..ebbe6c730c 100644 --- a/tests/test_remove_repeated_channel.py +++ b/tests/test_remove_repeated_channel.py @@ -11,16 +11,19 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import RemoveRepeatedChannel -TEST_CASE_1 = [{"repeats": 2}, torch.Tensor([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)] +TEST_CASES = [] +for q in (torch.Tensor, np.array): + TEST_CASES.append([{"repeats": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)]) # type: ignore class TestRemoveRepeatedChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_shape): result = RemoveRepeatedChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_repeat_channel.py b/tests/test_repeat_channel.py index 643ebc64de..d19abcc8b4 100644 --- a/tests/test_repeat_channel.py +++ b/tests/test_repeat_channel.py @@ -12,15 +12,18 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RepeatChannel -TEST_CASE_1 = [{"repeats": 3}, np.array([[[0, 1], [1, 2]]]), (3, 2, 2)] +TEST_CASES = [] +for q in (torch.Tensor, np.array): + TEST_CASES.append([{"repeats": 3}, q([[[0, 1], [1, 2]]]), (3, 2, 2)]) # type: ignore class TestRepeatChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_data, expected_shape): result = RepeatChannel(**input_param)(input_data) self.assertEqual(result.shape, expected_shape) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index fd505d8ba9..9ba9eeb4bc 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -17,25 +17,20 @@ from monai.transforms import Transpose -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_1 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +TEST_CASES = [] +for q in (np.arange, torch.Tensor): + TEST_CASES.append([q(5 * 4).reshape(5, 4), None]) + TEST_CASES.append([q(5 * 4 * 3).reshape(5, 4, 3), [2, 0, 1]]) class TestTranspose(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_transpose(self, im, indices): tr = Transpose(indices) - for array_class in (np.array, torch.Tensor): - out1 = tr(array_class(im)) - out2 = np.transpose(im, indices) - np.testing.assert_array_equal(out1, out2) + + out1 = tr(im) + out2 = np.transpose(im, indices) + np.testing.assert_array_equal(out1, out2) if __name__ == "__main__": From 494db1df1ffca4f09cede24de75c69dd8afc47a6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 1 Jun 2021 18:00:10 +0100 Subject: [PATCH 014/176] MaskIntensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 23 +++++++-------- tests/test_mask_intensity.py | 45 +++++++++++++++++------------ 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 716c5f92e2..b9534d5bb9 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform 
import RandomizableTransform, TorchOrNumpyTransform, Transform +from monai.transforms.transform import RandomizableTransform, TorchOrNumpyTransform, Transform, convert_data_type from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -786,7 +786,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class MaskIntensity(Transform): +class MaskIntensity(TorchOrNumpyTransform): """ Mask the intensity values of input image with the specified mask data. Mask data must have the same spatial size as the input image, and all @@ -801,10 +801,10 @@ class MaskIntensity(Transform): """ - def __init__(self, mask_data: Optional[np.ndarray]) -> None: + def __init__(self, mask_data: Optional[DataObjects.Images]) -> None: self.mask_data = mask_data - def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mask_data: Optional[DataObjects.Images] = None) -> DataObjects.Images: """ Args: mask_data: if mask data is single channel, apply to every channel @@ -817,21 +817,20 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel. """ - if self.mask_data is None and mask_data is None: - raise ValueError("Unknown mask_data.") - mask_data_ = np.array([[1]]) - if self.mask_data is not None and mask_data is None: - mask_data_ = self.mask_data > 0 + if mask_data is not None: mask_data_ = mask_data > 0 - mask_data_ = np.asarray(mask_data_) + elif self.mask_data is not None: + mask_data_ = self.mask_data > 0 + else: + raise ValueError("Unknown mask_data.") if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]: raise ValueError( "When mask_data is not single channel, mask_data channels must match img, " f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." 
) - - return np.asarray(img * mask_data_) + mask_data_, _ = convert_data_type(mask_data_, type(img)) + return img * mask_data_ class SavitzkyGolaySmooth(Transform): diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 3131abe8bf..97d2ec98b9 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -12,31 +12,40 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MaskIntensity -TEST_CASE_1 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"mask_data": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), -] +TEST_CASES = [] +for p in (torch.Tensor, np.array): + for q in (torch.Tensor, np.array): + for r in (torch.Tensor, np.array): + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), # type: ignore + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), # type: ignore + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), # type: ignore + ] + ) class TestMaskIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TEST_CASES) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) np.testing.assert_allclose(result, expected_data) From ada481faaa87a46b067d860e8efe20f42fe38df8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 1 Jun 2021 19:17:16 +0100 Subject: [PATCH 015/176] gibbsnoise, and restore device Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 69 +++++++++++------------- monai/transforms/intensity/dictionary.py | 17 ++---- monai/transforms/post/array.py | 12 ++--- monai/transforms/spatial/array.py | 4 +- monai/transforms/transform.py | 23 +++++--- monai/transforms/utility/array.py | 20 +++---- tests/test_gibbs_noise.py | 15 +++--- 7 files changed, 74 insertions(+), 86 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index b9534d5bb9..e32b5152a7 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -829,7 +829,7 @@ def __call__(self, img: 
DataObjects.Images, mask_data: Optional[DataObjects.Imag "When mask_data is not single channel, mask_data channels must match img, " f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." ) - mask_data_, _ = convert_data_type(mask_data_, type(img)) + mask_data_, *_ = convert_data_type(mask_data_, type(img)) return img * mask_data_ @@ -1161,10 +1161,9 @@ class RandGibbsNoise(RandomizableTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. """ - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None: + def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: if len(alpha) != 2: raise AssertionError("alpha length must be 2.") @@ -1175,7 +1174,6 @@ def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_te self.alpha = alpha self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output RandomizableTransform.__init__(self, prob=prob) @@ -1186,13 +1184,8 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: if self._do_transform: # apply transform - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) + transform = GibbsNoise(self.sampled_alpha) img = transform(img) - else: - if isinstance(img, np.ndarray) and self.as_tensor_output: - img = torch.Tensor(img) - elif isinstance(img, torch.Tensor) and not self.as_tensor_output: - img = img.detach().cpu().numpy() return img def _randomize(self, _: Any) -> None: @@ -1204,7 +1197,7 @@ def _randomize(self, _: Any) -> None: self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) -class GibbsNoise(Transform): +class GibbsNoise(TorchOrNumpyTransform): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts are one of the common type of type artifacts appearing in MRI scans. @@ -1219,61 +1212,56 @@ class GibbsNoise(Transform): Args: alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. 
- """ - def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None: + def __init__(self, alpha: float = 0.5) -> None: if alpha > 1 or alpha < 0: raise AssertionError("alpha must take values in the interval [0,1].") self.alpha = alpha - self.as_tensor_output = as_tensor_output - self._device = torch.device("cpu") def __call__(self, img: DataObjects.Images) -> DataObjects.Images: n_dims = len(img.shape[1:]) - - # convert to ndarray to work with np.fft - _device = None - if isinstance(img, torch.Tensor): - _device = img.device - img = img.cpu().detach().numpy() - + data_type = type(img) # FT - k = self._shift_fourier(img, n_dims) + k = self._shift_fourier(img, n_dims, data_type) # build and apply mask - k = self._apply_mask(k) + k = self._apply_mask(k, data_type) # map back - img = self._inv_shift_fourier(k, n_dims) - return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img + return self._inv_shift_fourier(k, n_dims, data_type) - def _shift_fourier(self, x: DataObjects.Images, n_dims: int) -> np.ndarray: + def _shift_fourier(self, x: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies fourier transform and shifts its output. Only the spatial dimensions get transformed. Args: - x (np.ndarray): tensor to fourier transform. + x: tensor to fourier transform. """ - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) + # argument is dim if torch, else axes + mod, arg = (torch, "dim") if data_type is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + out: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) # type: ignore return out - def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int) -> np.ndarray: + def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out + dims = tuple(range(-n_dims, 0)) + out: DataObjects.Images + if data_type is torch.Tensor: + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real - def _apply_mask(self, k: np.ndarray) -> np.ndarray: + def _apply_mask(self, k: DataObjects.Images, data_type: type) -> DataObjects.Images: """Builds and applies a mask on the spatial dimensions. Args: - k (np.ndarray): k-space version of the image. + k: k-space version of the image. Returns: masked version of the k-space image. """ @@ -1294,6 +1282,9 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray: # add channel dimension into mask mask = np.repeat(mask[None], k.shape[0], axis=0) + if data_type is torch.Tensor: + mask = torch.Tensor(mask) + # apply binary mask - k_masked: np.ndarray = k * mask - return k_masked + out: DataObjects.Images = k * mask # type: ignore + return out diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 4199fe0c44..1deedfc394 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1041,7 +1041,6 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): values in the interval [0,1] with alpha = 0 acting as the identity mapping. 
If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ @@ -1050,7 +1049,6 @@ def __init__( keys: KeysCollection, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: @@ -1058,7 +1056,6 @@ def __init__( RandomizableTransform.__init__(self, prob=prob) self.alpha = alpha self.sampled_alpha = -1.0 # stores last alpha sampled by randomize() - self.as_tensor_output = as_tensor_output def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: @@ -1068,13 +1065,8 @@ def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Imag for i, key in enumerate(self.key_iterator(d)): if self._do_transform: if i == 0: - transform = GibbsNoise(self.sampled_alpha, self.as_tensor_output) + transform = GibbsNoise(self.sampled_alpha) d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def _randomize(self, _: Any) -> None: @@ -1107,16 +1099,13 @@ class GibbsNoised(MapTransform): you need to transform. alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. - as_tensor_output: if true return torch.Tensor, else return np.array. default: True. allow_missing_keys: do not raise exception if key is missing. """ - def __init__( - self, keys: KeysCollection, alpha: float = 0.5, as_tensor_output: bool = True, allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) - self.transform = GibbsNoise(alpha, as_tensor_output) + self.transform = GibbsNoise(alpha) def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 05d9f3717a..e8a9917f9f 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -89,7 +89,7 @@ def __call__( if other is not None and not callable(other): raise TypeError(f"other must be None or callable but is {type(other).__name__}.") - img, orig_type = self.pre_conv_data(img) + img, orig_type, orig_device = self.pre_conv_data(img) # convert to float as activation must operate on float tensor img = img.float() # type: ignore @@ -105,7 +105,7 @@ def __call__( if act_func is not None: img = act_func(img) - return self.post_convert_data(img, orig_type) + return self.post_convert_data(img, orig_type, orig_device) class AsDiscrete(TorchTransform): @@ -169,7 +169,7 @@ def __call__( """ img_t: torch.Tensor - img_t, orig_type = self.pre_conv_data(img) # type: ignore + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore if argmax or self.argmax: img_t = torch.argmax(img_t, dim=1, keepdim=True) @@ -183,7 +183,7 @@ def __call__( if threshold_values or self.threshold_values: img_t = img_t >= (logit_thresh or self.logit_thresh) - return self.post_convert_data(img_t.float(), orig_type) + return self.post_convert_data(img_t.float(), orig_type, orig_device) class KeepLargestConnectedComponent(TorchTransform): @@ 
-260,7 +260,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         Returns:
             A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
         """
-        img, orig_type = self.pre_conv_data(img)
+        img, orig_type, orig_device = self.pre_conv_data(img)
 
         channel_dim = 1
         if img.shape[channel_dim] == 1:
@@ -296,7 +296,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
             img[:, self.applied_labels, ...] = applied_img.type(img.type())  # type: ignore
             output = img  # type: ignore
 
-        return self.post_convert_data(output, orig_type)
+        return self.post_convert_data(output, orig_type, orig_device)
 
 
 class LabelToContour(Transform):
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index c6e7f12332..6f38aea0c3 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -323,11 +323,11 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
         """
-        img, orig_type = self.pre_conv_data(img)
+        img, orig_type, orig_device = self.pre_conv_data(img)
 
         result = torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)).to(img.dtype)  # type: ignore
 
-        return self.post_convert_data(result, orig_type)
+        return self.post_convert_data(result, orig_type, orig_device)
 
 
 class Resize(ToDoTransform):
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 9d0e81331f..5f0b85b7ae 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -36,16 +36,23 @@
 ]
 
 
-def convert_data_type(data, output_type: type) -> Tuple[DataObjects.Images, type]:
+def convert_data_type(
+    data, output_type: type, device: Optional[torch.device] = None
+) -> Tuple[DataObjects.Images, type, Optional[torch.device]]:
     """Convert to torch.Tensor/np.ndarray."""
     orig_type = type(data)
     assert orig_type in (torch.Tensor, np.ndarray), f"Expected types {(torch.Tensor, np.ndarray)}, got {orig_type}"
 
+    orig_device = data.device if isinstance(data, torch.Tensor) else None
+
     if orig_type is np.ndarray and output_type is torch.Tensor:
         data = torch.Tensor(np.ascontiguousarray(data))
     elif orig_type is torch.Tensor and output_type is np.ndarray:
         data = data.detach().cpu().numpy()
 
-    return data, orig_type
+    if isinstance(data, torch.Tensor) and device is not None:
+        data = data.to(device)  # Tensor.to returns a new tensor, so the result must be assigned
+
+    return data, orig_type, orig_device
 
 
 def apply_transform(transform: Callable, data, map_items: bool = True):
@@ -222,17 +229,19 @@ def __call__(self, data: Any):
         """
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
 
-    def pre_conv_data(self, data: DataObjects.Images) -> Tuple[DataObjects.Images, type]:
+    def pre_conv_data(self, data: DataObjects.Images) -> Tuple[DataObjects.Images, type, Optional[torch.device]]:
         """Convert to torch/numpy, as required. Also return the original state so that after the transform,
         the data can be reverted to its original type.
         """
-        data, orig_type = convert_data_type(data, self.array_class)
-        return data, orig_type
+        data, orig_type, orig_device = convert_data_type(data, self.array_class)
+        return data, orig_type, orig_device
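To make the round-trip contract of these helpers concrete, here is a small sketch of how a
transform is expected to use them (`flip_first_spatial_axis` is a hypothetical helper, not part
of the patch; `convert_data_type` is the function defined above):

    import numpy as np
    import torch

    def flip_first_spatial_axis(img):  # img: np.ndarray or torch.Tensor
        # work in torch internally, remembering the original container type and device
        data, orig_type, orig_device = convert_data_type(img, torch.Tensor)
        data = torch.flip(data, dims=[1])
        # hand back the same container type (and device) the caller provided
        out, *_ = convert_data_type(data, orig_type, orig_device)
        return out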
""" - data, orig_type = convert_data_type(data, self.array_class) - return data, orig_type + data, orig_type, orig_device = convert_data_type(data, self.array_class) + return data, orig_type, orig_device @staticmethod - def post_convert_data(data: DataObjects.Images, output_type: type) -> DataObjects.Images: + def post_convert_data( + data: DataObjects.Images, output_type: type, ouput_device: Optional[torch.device] = None + ) -> DataObjects.Images: """Convert back to original type.""" - data, _ = convert_data_type(data, output_type) + data, *_ = convert_data_type(data, output_type, ouput_device) return data diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index b97e93c8f9..c51fe16ff3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -106,9 +106,9 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type = self.pre_conv_data(img) + img, orig_type, orig_device = self.pre_conv_data(img) img = torch.moveaxis(img, self.channel_dim, 0) # type: ignore - return self.post_convert_data(img, orig_type) + return self.post_convert_data(img, orig_type, orig_device) class AsChannelLast(TorchTransform): @@ -135,9 +135,9 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type = self.pre_conv_data(img) + img, orig_type, orig_device = self.pre_conv_data(img) img = torch.moveaxis(img, self.channel_dim, -1) # type: ignore - return self.post_convert_data(img, orig_type) + return self.post_convert_data(img, orig_type, orig_device) class AddChannel(TorchOrNumpyTransform): @@ -319,7 +319,7 @@ def __call__(self, img: Union[DataObjects.Images, Sequence]) -> torch.Tensor: if isinstance(img, Sequence): img = torch.Tensor(img) else: - img, _ = convert_data_type(img, torch.Tensor) + img, *_ = convert_data_type(img, torch.Tensor) img = img.contiguous() # type: ignore return img @@ -336,7 +336,7 @@ def __call__(self, img: Union[DataObjects.Images, Sequence]) -> np.ndarray: if isinstance(img, Sequence): img = np.array(img) else: - img, _ = convert_data_type(img, np.ndarray) + img, *_ = convert_data_type(img, np.ndarray) img = np.ascontiguousarray(img) return img @@ -381,9 +381,9 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type = self.pre_conv_data(img) + img, orig_type, orig_device = self.pre_conv_data(img) img = img.transpose(self.indices) # type: ignore - return self.post_convert_data(img, orig_type) + return self.post_convert_data(img, orig_type, orig_device) class SqueezeDim(TorchOrNumpyTransform): @@ -630,7 +630,7 @@ def __call__( merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. 
""" - img, orig_type = self.pre_conv_data(img) + img, orig_type, orig_device = self.pre_conv_data(img) if select_labels is None: select_labels = self.select_labels @@ -644,7 +644,7 @@ def __call__( out = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data - return self.post_convert_data(out, orig_type) # type: ignore + return self.post_convert_data(out, *orig_info) # type: ignore class FgBgToIndices(Transform): diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index c853d4686a..c174e96657 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -22,9 +22,8 @@ TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for as_tensor_input in (True, False): + TEST_CASES.append((shape, as_tensor_input)) class TestGibbsNoise(unittest.TestCase): @@ -42,17 +41,17 @@ def get_data(im_shape, as_tensor_input): return torch.Tensor(im) if as_tensor_input else im @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + def test_same_result(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = 0.8 - t = GibbsNoise(alpha, as_tensor_output) + t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): + def test_identity(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = 0.0 t = GibbsNoise(alpha) @@ -60,7 +59,7 @@ def test_identity(self, im_shape, _, as_tensor_input): np.testing.assert_allclose(im, out, atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): + def test_alpha_1(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = 1.0 t = GibbsNoise(alpha) From 12d8ac7041a1381d7d5c20b3c381998a86af3ce9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 12:26:11 +0100 Subject: [PATCH 016/176] rotate90, randrotate90 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 15 ++++---- tests/test_rand_rotate90.py | 61 +++++++++++++++++-------------- tests/test_rotate90.py | 53 +++++++++++++++------------ 3 files changed, 70 insertions(+), 59 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6f38aea0c3..f84de05427 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -28,6 +28,7 @@ RandomizableTransform, ThreadUnsafe, ToDoTransform, + TorchOrNumpyTransform, TorchTransform, Transform, ) @@ -597,7 +598,7 @@ def __call__( return zoomed[tuple(slice_vec)] -class Rotate90(ToDoTransform): +class Rotate90(TorchOrNumpyTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. 
See np.rot90 for additional details: @@ -619,17 +620,17 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - - result: np.ndarray = np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) - return result.astype(img.dtype) + if isinstance(img, torch.Tensor): + return torch.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)).to(img.dtype) + return np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)).astype(img.dtype) # type: ignore -class RandRotate90(ToDoTransform, RandomizableTransform): +class RandRotate90(TorchOrNumpyTransform, RandomizableTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -654,7 +655,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 50a1b28e53..478b32b7ed 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import RandRotate90 from tests.utils import NumpyImageTestCase2D @@ -20,43 +21,47 @@ class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - rotate.set_random_state(123) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotate.set_random_state(123) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) def test_k(self): rotate = RandRotate90(max_k=2) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 0, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, 
expected)) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - rotate.set_random_state(234) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotate.set_random_state(234) + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) if __name__ == "__main__": diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 4ab39d5cf6..30c20820b4 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import Rotate90 from tests.utils import NumpyImageTestCase2D @@ -20,39 +21,43 @@ class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) def test_k(self): rotate = Rotate90(k=2) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 1, (0, -1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 1, (0, -1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - rotated = rotate(self.imt[0]) - expected = [] - for channel in self.imt[0]: - expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + for p in (torch.Tensor, np.array): + rotated = rotate(p(self.imt[0])) + expected = [] + for channel in self.imt[0]: + expected.append(np.rot90(channel, 2, (0, 1))) + expected = np.stack(expected) + self.assertTrue(np.allclose(rotated, expected)) if __name__ == "__main__": From 5f8c1c7ebd44c8c327ff779fc0ccc7664309f71e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 12:31:54 +0100 Subject: [PATCH 017/176] rand_gibbs_noise Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- tests/test_rand_gibbs_noise.py | 21 ++++++++++----------- 2 files changed, 11 insertions(+), 12 
deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index e32b5152a7..4f4378e35e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1142,7 +1142,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: ) -class RandGibbsNoise(RandomizableTransform): +class RandGibbsNoise(TorchOrNumpyTransform, RandomizableTransform): """ Naturalistic image augmentation via Gibbs artifacts. The transform randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index ef2fe25eb4..0f8dafdeb8 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -22,9 +22,8 @@ TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for as_tensor_input in (True, False): + TEST_CASES.append((shape, as_tensor_input)) class TestRandGibbsNoise(unittest.TestCase): @@ -42,27 +41,27 @@ def get_data(im_shape, as_tensor_input): return torch.Tensor(im) if as_tensor_input else im @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): + def test_0_prob(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = [0.5, 1.0] - t = RandGibbsNoise(0.0, alpha, as_tensor_output) + t = RandGibbsNoise(0.0, alpha) out = t(im) np.testing.assert_allclose(im, out) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + def test_same_result(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = [0.5, 0.8] - t = RandGibbsNoise(1.0, alpha, as_tensor_output) + t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) + self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): + def test_identity(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) @@ -70,7 +69,7 @@ def test_identity(self, im_shape, _, as_tensor_input): np.testing.assert_allclose(im, out, atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): + def test_alpha_1(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) @@ -78,7 +77,7 @@ def test_alpha_1(self, im_shape, _, as_tensor_input): np.testing.assert_allclose(0 * im, out) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): + def test_alpha(self, im_shape, as_tensor_input): im = self.get_data(im_shape, as_tensor_input) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) From 0ec1ff16473b4e9624b446d955905a06c78db103 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 16:12:27 +0100 Subject: [PATCH 018/176] crop_foreground Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 13 +++-- tests/test_crop_foreground.py | 91 ++++++++++++++++++------------- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git 
a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index fbb4741535..6369cc2c9f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -21,7 +21,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import Randomizable, TorchOrNumpyTransform, Transform +from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, Transform, convert_data_type from monai.transforms.utils import ( compute_divisible_spatial_size, generate_pos_neg_label_crop_centers, @@ -471,7 +471,7 @@ def __call__(self, img: np.ndarray) -> List[np.ndarray]: return [self.cropper(img) for _ in range(self.num_samples)] -class CropForeground(Transform): +class CropForeground(NumpyTransform): """ Crop an image using a bounding box. The bounding box is generated by selecting foreground using select_fn at channels channel_indices. margin is added in each spatial dimension of the bounding box. @@ -534,12 +534,13 @@ def __init__( self.k_divisible = k_divisible self.mode: NumpyPadMode = NumpyPadMode(mode) - def compute_bounding_box(self, img: np.ndarray): + def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np.ndarray]: """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ + img, *_ = convert_data_type(img, np.ndarray) box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) box_start_ = np.asarray(box_start, dtype=np.int16) box_end_ = np.asarray(box_end, dtype=np.int16) @@ -562,14 +563,18 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=self.mode)(cropped) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
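        For instance, a short sketch of the intended behaviour (illustrative only, assuming
        the class as modified in this hunk; the bounding box is always computed in numpy, but
        the output matches the input container type):

            import torch
            from monai.transforms import CropForeground

            img = torch.zeros(1, 5, 5)
            img[0, 1:4, 1:4] = 1
            cropper = CropForeground(select_fn=lambda x: x > 0, margin=0)
            cropped = cropper(img)  # converted to numpy internally, then back
            assert isinstance(cropped, torch.Tensor) and tuple(cropped.shape) == (1, 3, 3)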
""" + img, orig_type, orig_device = self.pre_conv_data(img) + box_start, box_end = self.compute_bounding_box(img) cropped = self.crop_pad(img, box_start, box_end) + cropped = self.post_convert_data(cropped, orig_type, orig_device) + if self.return_coords: return cropped, box_start, box_end return cropped diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 8eae8f484e..1e7bf3bca1 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,60 +12,77 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground -TEST_CASE_1 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TESTS = [] +for p in (torch.Tensor, np.array): + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) -TEST_CASE_2 = [ - {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[3]]]), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[3]]]), + ] + ) -TEST_CASE_3 = [ - {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) -TEST_CASE_4 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_5 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1, 0], [2, 3, 
2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, - np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.zeros((1, 0, 0)), -] + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + np.zeros((1, 0, 0)), + ] + ) class TestCropForeground(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) np.testing.assert_allclose(result, expected_data) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand([TESTS[0]]) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) From 207d18f594596d9658e367d7be0248c4994a59e5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 16:23:15 +0100 Subject: [PATCH 019/176] format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 6369cc2c9f..058da71424 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -540,8 +540,9 @@ def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np. And adjust bounding box coords to be divisible by `k`. """ - img, *_ = convert_data_type(img, np.ndarray) - box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + box_start, box_end = generate_spatial_bounding_box(img_np, self.select_fn, self.channel_indices, self.margin) box_start_ = np.asarray(box_start, dtype=np.int16) box_end_ = np.asarray(box_end, dtype=np.int16) orig_spatial_size = box_end_ - box_start_ @@ -563,15 +564,16 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=self.mode)(cropped) - def __call__(self, img: DataObjects.Images) -> DataObjects.Images: + def __call__(self, img: DataObjects.Images): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. 
""" - img, orig_type, orig_device = self.pre_conv_data(img) + img_np: np.ndarray + img_np, orig_type, orig_device = self.pre_conv_data(img) # type: ignore - box_start, box_end = self.compute_bounding_box(img) - cropped = self.crop_pad(img, box_start, box_end) + box_start, box_end = self.compute_bounding_box(img_np) + cropped = self.crop_pad(img_np, box_start, box_end) cropped = self.post_convert_data(cropped, orig_type, orig_device) From 1b578a2ddf783a280cb95ffbeec0025b572fb87d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:04:52 +0100 Subject: [PATCH 020/176] torch_seed and rician_noise Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 23 +++++++++++++---------- monai/transforms/transform.py | 3 +++ tests/test_rand_rician_noise.py | 21 +++++++++++++-------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 4f4378e35e..54673e0646 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -95,7 +95,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img + self._noise.astype(dtype) -class RandRicianNoise(RandomizableTransform): +class RandRicianNoise(TorchOrNumpyTransform, RandomizableTransform): """ Add Rician noise to image. Rician noise in MRI is the result of performing a magnitude operation on complex @@ -134,18 +134,21 @@ def __init__( self.channel_wise = channel_wise self.relative = relative self.sample_std = sample_std - self._noise1 = None - self._noise2 = None + self._noise1: Optional[DataObjects.Images] = None + self._noise2: Optional[DataObjects.Images] = None def _add_noise(self, img: DataObjects.Images, mean: float, std: float): im_shape = img.shape - _std = self.R.uniform(0, std) if self.sample_std else std - self._noise1 = self.R.normal(mean, _std, size=im_shape) - self._noise2 = self.R.normal(mean, _std, size=im_shape) - if self._noise1 is None or self._noise2 is None: - raise AssertionError - dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype - return np.sqrt((img + self._noise1.astype(dtype)) ** 2 + self._noise2.astype(dtype) ** 2) + if isinstance(img, torch.Tensor): + _std = float(torch.rand(1, generator=self.R_torch)) * std if self.sample_std else std + self._noise1 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype) + self._noise2 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype) + return torch.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) + else: + _std = self.R.uniform(0, std) if self.sample_std else std + self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) + self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) + return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 5f0b85b7ae..9418485c60 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -128,6 +128,7 @@ class Randomizable(ABC, ThreadUnsafe): """ R: np.random.RandomState = np.random.RandomState() + R_torch = torch.Generator() def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -152,6 +153,7 @@ def set_random_state( _seed = id(seed) if not 
isinstance(seed, (int, np.integer)) else seed _seed = _seed % MAX_SEED self.R = np.random.RandomState(_seed) + self.R_torch = self.R_torch.manual_seed(_seed) return self if state is not None: @@ -161,6 +163,7 @@ def set_random_state( return self self.R = np.random.RandomState() + self.R_torch = torch.Generator() return self def randomize(self, data: Any) -> None: diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 6504fd9069..28c97abb6c 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -13,13 +13,19 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import RandRicianNoise from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +TESTS = [ + ("test_zero_mean", 0, 0.1), + ("test_non_zero_mean", 1, 0.5), +] + class TestRandRicianNoise(NumpyImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) + @parameterized.expand(TESTS) def test_correct_results(self, _, mean, std): seed = 0 rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) @@ -36,18 +42,17 @@ def test_correct_results(self, _, mean, std): class TestRandRicianNoiseTorch(TorchImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) + @parameterized.expand(TESTS) def test_correct_results(self, _, mean, std): seed = 0 rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) rician_fn.set_random_state(seed) noised = rician_fn(self.imt) - np.random.seed(seed) - np.random.random() - _std = np.random.uniform(0, std) - expected = np.sqrt( - (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2 - + np.random.normal(mean, _std, size=self.imt.shape) ** 2 + torch.manual_seed(seed) + _std = float(torch.rand(1)) * std + expected = torch.sqrt( + (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 + + torch.normal(mean, _std, size=self.imt.shape) ** 2 ) np.testing.assert_allclose(expected, noised, atol=1e-5) From c0aad878a9d2c9d4c34ec0f6fb10d2b541b93b99 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:12:48 +0100 Subject: [PATCH 021/176] format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 6 +++--- monai/transforms/intensity/array.py | 2 +- tests/test_rand_rician_noise.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d1f642948e..e5c3df9b64 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -687,7 +687,7 @@ def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: ret.append(cropped) return ret - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse @@ -885,7 +885,7 @@ def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) @@ -1042,7 
+1042,7 @@ def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: return results - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 54673e0646..605d798b0d 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -148,7 +148,7 @@ def _add_noise(self, img: DataObjects.Images, mean: float, std: float): _std = self.R.uniform(0, std) if self.sample_std else std self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) - return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) + return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) # type: ignore def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 28c97abb6c..4779895fdd 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -12,8 +12,8 @@ import unittest import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import RandRicianNoise from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D From 9d75e05bdc9b1e5aeba2efdecdd7a5aed14c728f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:23:30 +0100 Subject: [PATCH 022/176] seed as int Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 9418485c60..bd2c04d73d 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -153,7 +153,7 @@ def set_random_state( _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed _seed = _seed % MAX_SEED self.R = np.random.RandomState(_seed) - self.R_torch = self.R_torch.manual_seed(_seed) + self.R_torch = self.R_torch.manual_seed(int(_seed)) return self if state is not None: From 20283d0921947d2cc86693115ee31fe73409f3de Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:33:40 +0100 Subject: [PATCH 023/176] inverses Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index e5c3df9b64..32aea4c294 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -37,7 +37,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import MapTransform, Randomizable, convert_data_type from monai.transforms.utils import ( allow_missing_keys_mode, generate_pos_neg_label_crop_centers, @@ -852,7 +852,7 @@ def randomize(self, weight_map: np.ndarray) -> None: def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) - self.randomize(d[self.w_key]) + 
self.randomize(convert_data_type(d[self.w_key], np.ndarray)[0]) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) results: List[DataObjects.Dict] = [{} for _ in range(self.num_samples)] @@ -1011,10 +1011,14 @@ def randomize( def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) - label = d[self.label_key] - image = d[self.image_key] if self.image_key else None - fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None - bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None + label = convert_data_type(d[self.label_key], np.ndarray)[0] + image = convert_data_type(d[self.image_key], np.ndarray)[0] if self.image_key else None + fg_indices = ( + convert_data_type(d.get(self.fg_indices_key), np.ndarray)[0] if self.fg_indices_key is not None else None + ) + bg_indices = ( + convert_data_type(d.get(self.bg_indices_key), np.ndarray)[0] if self.bg_indices_key is not None else None + ) self.randomize(label, fg_indices, bg_indices, image) if not isinstance(self.spatial_size, tuple): From 754308b922f03b9be9adcb3a6993c323793c7268 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:41:38 +0100 Subject: [PATCH 024/176] gaussian_smooth and rand_gaussian_smooth Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 23 +++++--- tests/test_gaussian_smooth.py | 87 +++++++++++++++++++---------- 2 files changed, 74 insertions(+), 36 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 605d798b0d..ec984eebd6 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -22,7 +22,13 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, TorchOrNumpyTransform, Transform, convert_data_type +from monai.transforms.transform import ( + RandomizableTransform, + TorchOrNumpyTransform, + TorchTransform, + Transform, + convert_data_type, +) from monai.transforms.utils import rescale_array from monai.utils import ( PT_BEFORE_1_7, @@ -914,7 +920,7 @@ def __call__(self, img: np.ndarray): return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) -class GaussianSmooth(Transform): +class GaussianSmooth(TorchTransform): """ Apply Gaussian smooth to the input data based on specified `sigma` parameter. A default value `sigma=1.0` is provided for reference. @@ -932,13 +938,16 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er self.sigma = sigma self.approx = approx - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - return gaussian_filter(input_data).squeeze(0).detach().numpy() + out = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) + return self.post_convert_data(out, orig_type, orig_device) -class RandGaussianSmooth(RandomizableTransform): +class RandGaussianSmooth(TorchTransform, RandomizableTransform): """ Apply Gaussian smooth to the input data based on randomly selected `sigma` parameters. 
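Both `GaussianSmooth` and `RandGaussianSmooth` delegate the filtering to
`monai.networks.layers.GaussianFilter`, which operates on batched tensors. Roughly, the
pattern used above is the following (a sketch, not the exact implementation):

    import torch
    from monai.networks.layers import GaussianFilter

    def smooth(img_t: torch.Tensor, sigma: float = 1.0, approx: str = "erf") -> torch.Tensor:
        # img_t is channel-first (C, H[, W, D]); filter over the spatial dims only
        gf = GaussianFilter(img_t.ndim - 1, sigma, approx=approx)
        return gf(img_t.unsqueeze(0)).squeeze(0)  # add a batch dim, then drop it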
@@ -976,7 +985,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1]) self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1]) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize() if not self._do_transform: return img diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index e51977fbee..896b0ee5bc 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -12,50 +12,79 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import GaussianSmooth -TEST_CASE_1 = [ - {"sigma": 1.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( +TESTS = [] +for p in (torch.Tensor, np.array): + TESTS.append( [ - [ - [0.59167546, 0.69312394, 0.59167546], - [0.7956997, 0.93213004, 0.7956997], - [0.7668002, 0.8982755, 0.7668002], - ], - [[1.6105323, 1.8866735, 1.6105323], [1.9892492, 2.3303251, 1.9892492], [1.7856569, 2.091825, 1.7856569]], + {"sigma": 1.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + np.array( + [ + [ + [0.59167546, 0.69312394, 0.59167546], + [0.7956997, 0.93213004, 0.7956997], + [0.7668002, 0.8982755, 0.7668002], + ], + [ + [1.6105323, 1.8866735, 1.6105323], + [1.9892492, 2.3303251, 1.9892492], + [1.7856569, 2.091825, 1.7856569], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"sigma": 0.5}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8424794, 0.99864554, 0.8424794], [1.678146, 1.9892154, 1.678146], [1.9889624, 2.3576462, 1.9889624]], - [[2.966061, 3.5158648, 2.966061], [4.1953645, 4.973038, 4.1953645], [4.112544, 4.8748655, 4.1125436]], + {"sigma": 0.5}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + np.array( + [ + [ + [0.8424794, 0.99864554, 0.8424794], + [1.678146, 1.9892154, 1.678146], + [1.9889624, 2.3576462, 1.9889624], + ], + [ + [2.966061, 3.5158648, 2.966061], + [4.1953645, 4.973038, 4.1953645], + [4.112544, 4.8748655, 4.1125436], + ], + ] + ), ] - ), -] + ) -TEST_CASE_3 = [ - {"sigma": [1.5, 0.5]}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + TESTS.append( [ - [[0.8542037, 1.0125432, 0.8542037], [1.1487541, 1.3616928, 1.1487541], [1.1070318, 1.3122368, 1.1070318]], - [[2.3251305, 2.756128, 2.3251305], [2.8718853, 3.4042323, 2.8718853], [2.5779586, 3.0558217, 2.5779586]], + {"sigma": [1.5, 0.5]}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + np.array( + [ + [ + [0.8542037, 1.0125432, 0.8542037], + [1.1487541, 1.3616928, 1.1487541], + [1.1070318, 1.3122368, 1.1070318], + ], + [ + [2.3251305, 2.756128, 2.3251305], + [2.8718853, 3.4042323, 2.8718853], + [2.5779586, 3.0558217, 2.5779586], + ], + ] + ), ] - ), -] + ) class TestGaussianSmooth(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSmooth(**argments)(image) np.testing.assert_allclose(result, expected_data, rtol=1e-4) From 963fc48b1bdc1ff8676dada05e9211b6be50ae9d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:44:34 +0100 Subject: [PATCH 025/176] gibbs noise 
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_gibbs_noised.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 82ce220d89..a37dd7fa93 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -64,7 +64,7 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): out2 = t(deepcopy(data)) for k in KEYS: np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) def test_identity(self, im_shape, _, as_tensor_input): From 7cea565f12faa3755deeefe80bfb0130e48ca384 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Jun 2021 17:53:08 +0100 Subject: [PATCH 026/176] gaussian sharpen Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 20 ++-- tests/test_gaussian_sharpen.py | 70 +++++++------- tests/test_rand_gaussian_sharpen.py | 137 ++++++++++++++-------------- 3 files changed, 119 insertions(+), 108 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ec984eebd6..458e9b2fc2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -993,7 +993,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return GaussianSmooth(sigma=sigma, approx=self.approx)(img) -class GaussianSharpen(Transform): +class GaussianSharpen(TorchTransform): """ Sharpen images using the Gaussian Blur filter. Referring to: http://scipy-lectures.org/advanced/image_processing/auto_examples/plot_sharpen.html. @@ -1032,16 +1032,18 @@ def __init__( self.alpha = alpha self.approx = approx - def __call__(self, img: np.ndarray): - gaussian_filter1 = GaussianFilter(img.ndim - 1, self.sigma1, approx=self.approx) - gaussian_filter2 = GaussianFilter(img.ndim - 1, self.sigma2, approx=self.approx) - input_data = torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0) - blurred_f = gaussian_filter1(input_data) - filter_blurred_f = gaussian_filter2(blurred_f) - return (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0).detach().numpy() + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + + gf1, gf2 = [GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) for sigma in (self.sigma1, self.sigma2)] + blurred_f = gf1(img_t.unsqueeze(0)) + filter_blurred_f = gf2(blurred_f) + out = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) + return self.post_convert_data(out, orig_type, orig_device) -class RandGaussianSharpen(RandomizableTransform): +class RandGaussianSharpen(TorchTransform, RandomizableTransform): """ Sharpen images using the Gaussian Blur filter based on randomly selected `sigma1`, `sigma2` and `alpha`. The algorithm is :py:class:`monai.transforms.GaussianSharpen`. 
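For reference, `GaussianSharpen` is standard unsharp masking: blur the image, blur the result
again, and amplify the difference between the two blurs. A sketch of the computation using
`GaussianFilter` directly (illustrative sigma and alpha values):

    import torch
    from monai.networks.layers import GaussianFilter

    img = torch.rand(1, 1, 32, 32)  # (batch, channel, H, W)
    g1 = GaussianFilter(2, sigma=3.0)
    g2 = GaussianFilter(2, sigma=1.0)
    blurred = g1(img)
    sharpened = blurred + 10.0 * (blurred - g2(blurred))  # alpha = 10.0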
diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 9d078e65e5..b4443cc3ed 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -13,45 +13,49 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import GaussianSharpen -TEST_CASE_1 = [ - {}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], - ] - ), -] - -TEST_CASE_2 = [ - {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], - ] - ), -] - -TEST_CASE_3 = [ - {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], - ] - ), -] +TESTS = [] + +for p in (torch.Tensor, np.array): + TESTS.append([ + {}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), + ]) + + TESTS.append([ + {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], + [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], + ] + ), + ]) + + TESTS.append([ + {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], + [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], + ] + ), + ]) class TestGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpen(**argments)(image) np.testing.assert_allclose(result, expected_data, rtol=1e-4) diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 909f96f56b..6f690200a2 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -13,86 +13,91 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import RandGaussianSharpen -TEST_CASE_1 = [ - {"prob": 1.0}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 
22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], - ] - ), -] +TESTS = [] -TEST_CASE_2 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], - ] - ), -] +for p in (torch.Tensor, np.array): + TESTS.append([ + {"prob": 1.0}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], + [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], + ] + ), + ]) -TEST_CASE_3 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], - ] - ), -] + TESTS.append([ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], + [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], + ] + ), + ]) -TEST_CASE_4 = [ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], - ] - ), -] + TESTS.append([ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], + [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], + ] + ), + ]) + + TESTS.append([ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], + [[20.87405, 
18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), + ]) class TestRandGaussianSharpen(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): converter = RandGaussianSharpen(**argments) converter.set_random_state(seed=0) result = converter(image) np.testing.assert_allclose(result, expected_data, rtol=1e-4) + self.assertIsInstance(result, type(image)) if __name__ == "__main__": From e5da98ceae070fbdb207cb3f33651bb71fc002bb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 13:04:52 +0100 Subject: [PATCH 027/176] DetectEnvelope, test transforms w/ CUDA Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 19 +-- monai/transforms/intensity/array.py | 15 ++- monai/transforms/transform.py | 7 +- monai/utils/misc.py | 22 +-- tests/test_crop_foreground.py | 8 +- tests/test_detect_envelope.py | 13 +- tests/test_gaussian_sharpen.py | 97 +++++++++----- tests/test_gaussian_smooth.py | 9 +- tests/test_inverse.py | 3 +- tests/test_mask_intensity.py | 56 ++++---- tests/test_rand_gaussian_sharpen.py | 177 +++++++++++++++---------- tests/test_rand_rotate90.py | 14 +- tests/test_rotate90.py | 14 +- 13 files changed, 290 insertions(+), 164 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 32aea4c294..cf17b140e4 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -852,7 +852,8 @@ def randomize(self, weight_map: np.ndarray) -> None: def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) - self.randomize(convert_data_type(d[self.w_key], np.ndarray)[0]) + weight_map: np.ndarray = convert_data_type(d[self.w_key], np.ndarray)[0] # type: ignore + self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) results: List[DataObjects.Dict] = [{} for _ in range(self.num_samples)] @@ -1011,14 +1012,14 @@ def randomize( def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: d = dict(data) - label = convert_data_type(d[self.label_key], np.ndarray)[0] - image = convert_data_type(d[self.image_key], np.ndarray)[0] if self.image_key else None - fg_indices = ( - convert_data_type(d.get(self.fg_indices_key), np.ndarray)[0] if self.fg_indices_key is not None else None - ) - bg_indices = ( - convert_data_type(d.get(self.bg_indices_key), np.ndarray)[0] if self.bg_indices_key is not None else None - ) + + def get_as_np(x) -> np.ndarray: + return convert_data_type(x, np.ndarray)[0] # type: ignore + + label = get_as_np(d[self.label_key]) + image = get_as_np(d[self.image_key]) if self.image_key else None + fg_indices = get_as_np(d[self.fg_indices_key]) if self.fg_indices_key is not None else None + bg_indices = get_as_np(d[self.bg_indices_key]) if self.bg_indices_key is not None else None self.randomize(label, fg_indices, bg_indices, image) if not isinstance(self.spatial_size, tuple): diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 458e9b2fc2..f27f5c87b2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -880,7 +880,7 @@ def __call__(self, img: np.ndarray): return savgol_filter(input_data).squeeze(0).numpy() -class DetectEnvelope(Transform): +class 
DetectEnvelope(TorchTransform): """ Find the envelope of the input data along the requested axis using a Hilbert transform. Requires PyTorch 1.7.0+ and the PyTorch FFT module (which is not included in NVIDIA PyTorch Release 20.10). @@ -903,21 +903,24 @@ def __init__(self, axis: int = 1, n: Union[int, None] = None) -> None: self.axis = axis self.n = n - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: - img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. Returns: - np.ndarray containing envelope of data in img along the specified axis. + array containing envelope of data in img along the specified axis. """ + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return np.abs(hilbert_transform(input_data).squeeze(0).numpy()) + out = torch.abs(hilbert_transform(img_t.unsqueeze(0))).squeeze(0) + return self.post_convert_data(out, orig_type, orig_device) class GaussianSmooth(TorchTransform): diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index bd2c04d73d..1912dcfc12 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -23,6 +23,7 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy __all__ = [ "ThreadUnsafe", @@ -45,9 +46,11 @@ def convert_data_type( orig_device = data.device if isinstance(data, torch.Tensor) else None if orig_type is np.ndarray and output_type is torch.Tensor: - data = torch.Tensor(np.ascontiguousarray(data)) + dtype = dtype_numpy_to_torch(data.dtype) + data = torch.Tensor(np.ascontiguousarray(data)).to(dtype) elif orig_type is torch.Tensor and output_type is np.ndarray: - data = data.detach().cpu().numpy() + dtype = dtype_torch_to_numpy(data.dtype) + data = data.detach().cpu().numpy().astype(dtype) if isinstance(data, torch.Tensor) and device is not None: data.to(device) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bd8e46d8b5..4b982fa80b 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -290,17 +290,17 @@ def _parse_var(s): _torch_to_np_dtype = { - torch.bool: bool, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.complex64: np.complex64, - torch.complex128: np.complex128, + torch.bool: np.dtype(bool), + torch.uint8: np.dtype(np.uint8), + torch.int8: np.dtype(np.int8), + torch.int16: np.dtype(np.int16), + torch.int32: np.dtype(np.int32), + torch.int64: np.dtype(np.int64), + torch.float16: np.dtype(np.float16), + torch.float32: np.dtype(np.float32), + torch.float64: np.dtype(np.float64), + torch.complex64: np.dtype(np.complex64), + torch.complex128: np.dtype(np.complex128), } _np_to_torch_dtype = {value: key for key, value in _torch_to_np_dtype.items()} diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 
1e7bf3bca1..93b9de8fb4 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -10,6 +10,8 @@ # limitations under the License. import unittest +from functools import partial +from typing import Callable, List import numpy as np import torch @@ -18,7 +20,11 @@ from monai.transforms import CropForeground TESTS = [] -for p in (torch.Tensor, np.array): +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + +for p in NDARRAYS: TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index ded0290de2..d0ba57ede1 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -10,8 +10,10 @@ # limitations under the License. import unittest +from typing import Callable, List import numpy as np +import torch from parameterized import parameterized from monai.transforms import DetectEnvelope @@ -110,6 +112,12 @@ TEST_CASE_INVALID_OBJ = [{}, "a string", "__call__"] # method expected to raise exception +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + @SkipIfBeforePyTorchVersion((1, 7)) @SkipIfNoModule("torch.fft") @@ -125,8 +133,9 @@ class TestDetectEnvelope(unittest.TestCase): ] ) def test_value(self, arguments, image, expected_data, atol): - result = DetectEnvelope(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) + for p in NDARRAYS: + result = DetectEnvelope(**arguments)(p(image)) + np.testing.assert_allclose(result, expected_data, atol=atol) @parameterized.expand( [ diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index b4443cc3ed..0221d6e241 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -10,48 +10,81 @@ # limitations under the License. 
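The pattern introduced in the two test files above recurs throughout this commit: each test builds a list of array constructors and, when a GPU is visible, appends a CUDA-backed constructor via functools.partial, so every transform is exercised against numpy arrays, CPU tensors, and CUDA tensors. A minimal standalone sketch of the idea (the names here are illustrative, not part of MONAI):

from functools import partial
from typing import Callable, List

import numpy as np
import torch

constructors: List[Callable] = [np.array, torch.Tensor]
if torch.cuda.is_available():
    constructors.append(partial(torch.Tensor, device="cuda"))

for make in constructors:
    img = make([[1.0, 2.0], [3.0, 4.0]])
    # same values each time, but typed np.ndarray, CPU torch.Tensor,
    # or CUDA torch.Tensor depending on the constructor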
import unittest +from typing import Callable, List import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import GaussianSharpen TESTS = [] -for p in (torch.Tensor, np.array): - TESTS.append([ - {}, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.1081963, 3.4950666, 4.1081963], [3.7239995, 2.8491793, 3.7239995], [4.569839, 3.9529324, 4.569839]], - [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], - ] - ), - ]) +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + +for p in NDARRAYS: + TESTS.append( + [ + {}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [4.1081963, 3.4950666, 4.1081963], + [3.7239995, 2.8491793, 3.7239995], + [4.569839, 3.9529324, 4.569839], + ], + [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]], + ] + ), + ] + ) - TESTS.append([ - {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.513644, 4.869134, 4.513644], [8.467242, 9.4004135, 8.467242], [10.416813, 12.0653515, 10.416813]], - [[15.711488, 17.569994, 15.711488], [21.16811, 23.501041, 21.16811], [21.614658, 24.766209, 21.614658]], - ] - ), - ]) + TESTS.append( + [ + {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [4.513644, 4.869134, 4.513644], + [8.467242, 9.4004135, 8.467242], + [10.416813, 12.0653515, 10.416813], + ], + [ + [15.711488, 17.569994, 15.711488], + [21.16811, 23.501041, 21.16811], + [21.614658, 24.766209, 21.614658], + ], + ] + ), + ] + ) - TESTS.append([ - {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[3.3324685, 3.335536, 3.3324673], [7.7666636, 8.16056, 7.7666636], [12.662973, 14.317837, 12.6629715]], - [[15.329051, 16.57557, 15.329051], [19.41665, 20.40139, 19.416655], [24.659554, 27.557873, 24.659554]], - ] - ), - ]) + TESTS.append( + [ + {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [3.3324685, 3.335536, 3.3324673], + [7.7666636, 8.16056, 7.7666636], + [12.662973, 14.317837, 12.6629715], + ], + [ + [15.329051, 16.57557, 15.329051], + [19.41665, 20.40139, 19.416655], + [24.659554, 27.557873, 24.659554], + ], + ] + ), + ] + ) class TestGaussianSharpen(unittest.TestCase): diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index 896b0ee5bc..6305eae0ee 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -10,6 +10,7 @@ # limitations under the License. 
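For reference, parameterized.expand unpacks each entry of the generated TESTS list into the positional arguments of the decorated method, so the (arguments, image, expected) triples built in the loops above arrive as separate parameters. A toy example of the same mechanism, using a hypothetical doubling "transform":

import unittest

import numpy as np
from parameterized import parameterized

CASES = [
    [{"factor": 2}, np.array([1.0, 2.0]), np.array([2.0, 4.0])],
]

class TestDouble(unittest.TestCase):
    @parameterized.expand(CASES)
    def test_value(self, args, image, expected):
        # each CASES entry is unpacked into (args, image, expected)
        np.testing.assert_allclose(image * args["factor"], expected)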
import unittest +from typing import Callable, List import numpy as np import torch @@ -18,7 +19,13 @@ from monai.transforms import GaussianSmooth TESTS = [] -for p in (torch.Tensor, np.array): +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + +for p in NDARRAYS: TESTS.append( [ {"sigma": 1.5}, diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f1ce314b01..5fd130eb52 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -60,6 +60,7 @@ allow_missing_keys_mode, convert_inverse_interp_mode, ) +from monai.transforms.utility.dictionary import ToTensord from monai.utils import first, get_seed, optional_import, set_determinism from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine @@ -470,7 +471,7 @@ ) ) -TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose([ToTensord(KEYS), Compose(t[3:])])) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 97d2ec98b9..fa286290a9 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from typing import Callable, List import numpy as np import torch @@ -18,30 +19,37 @@ from monai.transforms import MaskIntensity TEST_CASES = [] -for p in (torch.Tensor, np.array): - for q in (torch.Tensor, np.array): - for r in (torch.Tensor, np.array): - TEST_CASES.append( - [ - {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), # type: ignore - ] - ) - TEST_CASES.append( - [ - {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), # type: ignore - ] - ) - TEST_CASES.append( - [ - {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - r([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), # type: ignore - ] - ) + +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + +for p in NDARRAYS: + for q in NDARRAYS: + + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore + np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + ] + ) + TEST_CASES.append( + [ + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, # type: ignore + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # 
type: ignore + np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), + ] + ) class TestMaskIntensity(unittest.TestCase): diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 6f690200a2..6bf5b4bb58 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -10,84 +10,127 @@ # limitations under the License. import unittest +from functools import partial +from typing import Callable, List import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import RandGaussianSharpen TESTS = [] -for p in (torch.Tensor, np.array): - TESTS.append([ - {"prob": 1.0}, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[5.2919216, 5.5854445, 5.29192], [11.3982, 12.62332, 11.398202], [14.870525, 17.323769, 14.870527]], - [[20.413757, 22.767355, 20.413757], [28.495504, 31.558315, 28.495499], [29.99236, 34.505676, 29.992361]], - ] - ), - ]) - TESTS.append([ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": 0.4, - "sigma2_y": 0.4, - "sigma2_z": 0.4, - "prob": 1.0, - }, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.1071496, 3.597953, 4.1071477], [10.062014, 9.825114, 10.0620165], [14.698058, 15.818766, 14.698058]], - [[18.211048, 18.16049, 18.211048], [25.155039, 24.56279, 25.155039], [28.801964, 30.381308, 28.801964]], - ] - ), - ]) +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + +for p in NDARRAYS: + TESTS.append( + [ + {"prob": 1.0}, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [5.2919216, 5.5854445, 5.29192], + [11.3982, 12.62332, 11.398202], + [14.870525, 17.323769, 14.870527], + ], + [ + [20.413757, 22.767355, 20.413757], + [28.495504, 31.558315, 28.495499], + [29.99236, 34.505676, 29.992361], + ], + ] + ), + ] + ) + + TESTS.append( + [ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": 0.4, + "sigma2_y": 0.4, + "sigma2_z": 0.4, + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [4.1071496, 3.597953, 4.1071477], + [10.062014, 9.825114, 10.0620165], + [14.698058, 15.818766, 14.698058], + ], + [ + [18.211048, 18.16049, 18.211048], + [25.155039, 24.56279, 25.155039], + [28.801964, 30.381308, 28.801964], + ], + ] + ), + ] + ) - TESTS.append([ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "prob": 1.0, - }, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.81077, 4.4237204, 4.81077], [12.061236, 12.298177, 12.061236], [17.362553, 19.201174, 17.362553]], - [[21.440754, 22.142393, 21.440754], [30.15308, 30.745445, 30.153086], [33.99255, 36.919838, 33.99255]], - ] - ), - ]) + TESTS.append( + [ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [4.81077, 4.4237204, 4.81077], + [12.061236, 12.298177, 12.061236], + [17.362553, 19.201174, 
17.362553], + ], + [ + [21.440754, 22.142393, 21.440754], + [30.15308, 30.745445, 30.153086], + [33.99255, 36.919838, 33.99255], + ], + ] + ), + ] + ) - TESTS.append([ - { - "sigma1_x": (0.5, 0.75), - "sigma1_y": (0.5, 0.75), - "sigma1_z": (0.5, 0.75), - "sigma2_x": (0.5, 0.75), - "sigma2_y": (0.5, 0.75), - "sigma2_z": (0.5, 0.75), - "approx": "scalespace", - "prob": 1.0, - }, - p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( - [ - [[4.430213, 3.2278745, 4.4302144], [10.325399, 8.507457, 10.325399], [17.494898, 16.5609, 17.494894]], - [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], - ] - ), - ]) + TESTS.append( + [ + { + "sigma1_x": (0.5, 0.75), + "sigma1_y": (0.5, 0.75), + "sigma1_z": (0.5, 0.75), + "sigma2_x": (0.5, 0.75), + "sigma2_y": (0.5, 0.75), + "sigma2_z": (0.5, 0.75), + "approx": "scalespace", + "prob": 1.0, + }, + p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + np.array( + [ + [ + [4.430213, 3.2278745, 4.4302144], + [10.325399, 8.507457, 10.325399], + [17.494898, 16.5609, 17.494894], + ], + [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]], + ] + ), + ] + ) class TestRandGaussianSharpen(unittest.TestCase): diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 478b32b7ed..313ff64353 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -10,6 +10,8 @@ # limitations under the License. import unittest +from functools import partial +from typing import Callable, List import numpy as np import torch @@ -17,11 +19,15 @@ from monai.transforms import RandRotate90 from tests.utils import NumpyImageTestCase2D +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotate.set_random_state(123) rotated = rotate(p(self.imt[0])) expected = [] @@ -32,7 +38,7 @@ def test_default(self): def test_k(self): rotate = RandRotate90(max_k=2) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] @@ -43,7 +49,7 @@ def test_k(self): def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] @@ -54,7 +60,7 @@ def test_spatial_axes(self): def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 30c20820b4..ee77d9eae9 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -10,6 +10,8 @@ # limitations under the License. 
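Note why set_random_state is re-called inside the constructor loop in the rotate tests above: a randomizable transform draws from its own np.random.RandomState, so reseeding before each input type guarantees the numpy and torch runs see the same random draw and produce comparable outputs. A minimal sketch:

import numpy as np
import torch

from monai.transforms import RandRotate90

img = np.arange(8.0).reshape((2, 2, 2))  # toy channel-first image
rotate = RandRotate90(prob=1.0, max_k=2)

results = []
for make in (np.array, torch.Tensor):
    rotate.set_random_state(234)  # reseed so both runs draw the same k
    results.append(rotate(make(img)))

np.testing.assert_allclose(results[0], results[1])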
import unittest +from functools import partial +from typing import Callable, List import numpy as np import torch @@ -17,11 +19,15 @@ from monai.transforms import Rotate90 from tests.utils import NumpyImageTestCase2D +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -31,7 +37,7 @@ def test_rotate90_default(self): def test_k(self): rotate = Rotate90(k=2) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -41,7 +47,7 @@ def test_k(self): def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -51,7 +57,7 @@ def test_spatial_axes(self): def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - for p in (torch.Tensor, np.array): + for p in NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: From 618c4b3ce3137f159b103350aa570a37e26387fc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 13:11:20 +0100 Subject: [PATCH 028/176] SavitzkyGolaySmooth Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 15 +++++++++------ tests/test_savitzky_golay_smooth.py | 12 ++++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index f27f5c87b2..4817214ce9 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -842,7 +842,7 @@ def __call__(self, img: DataObjects.Images, mask_data: Optional[DataObjects.Imag return img * mask_data_ -class SavitzkyGolaySmooth(Transform): +class SavitzkyGolaySmooth(TorchTransform): """ Smooth the input data along the given axis using a Savitzky-Golay filter. @@ -864,20 +864,23 @@ def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = "z self.axis = axis self.mode = mode - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: - img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. + img: array containing input data. Must be real and in shape [channels, spatial1, spatial2, ...]. Returns: - np.ndarray containing smoothed result. + array containing smoothed result. 
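The SavitzkyGolaySmooth body that follows is the general recipe this series applies to every torch-backed transform: convert the input to a tensor, run the torch implementation, then convert back to the caller's original type and device. A standalone sketch of that round trip (to_torch/restore are illustrative stand-ins for the pre_conv_data/post_convert_data helpers):

import numpy as np
import torch

def to_torch(x):
    # convert to tensor, remembering what the caller passed in
    if isinstance(x, np.ndarray):
        return torch.as_tensor(x), np.ndarray, None
    return x, torch.Tensor, x.device

def restore(x, orig_type, orig_device):
    # hand back the same container type (and device) the caller used
    if orig_type is np.ndarray:
        return x.detach().cpu().numpy()
    return x.to(orig_device)

def smooth(img):
    img_t, orig_type, orig_device = to_torch(img)
    out = img_t * 0.5  # the real transform applies SavitzkyGolayFilter here
    return restore(out, orig_type, orig_device)

smooth(np.ones(4))     # returns np.ndarray
smooth(torch.ones(4))  # returns torch.Tensor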
""" + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - input_data = torch.as_tensor(np.ascontiguousarray(img)).unsqueeze(0) - return savgol_filter(input_data).squeeze(0).numpy() + out = savgol_filter(img_t.unsqueeze(0)).squeeze(0) + return self.post_convert_data(out, orig_type, orig_device) class DetectEnvelope(TorchTransform): diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 63dcce1b05..10553179d3 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -9,10 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, List import unittest import numpy as np from parameterized import parameterized +import torch from monai.transforms import SavitzkyGolaySmooth @@ -63,8 +65,14 @@ def test_value(self, arguments, image, expected_data, atol): np.testing.assert_allclose(result, expected_data, atol=atol) +from functools import partial +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestSavitzkyGolaySmoothREP(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - np.testing.assert_allclose(result, expected_data, atol=atol) + for p in NDARRAYS: + result = SavitzkyGolaySmooth(**arguments)(p(image)) + np.testing.assert_allclose(result, expected_data, atol=atol) From 53b381a88464a4b76d76dfd1dc1ffcc7f417e30f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 13:26:34 +0100 Subject: [PATCH 029/176] scale intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 12 +++++++---- monai/transforms/utils.py | 16 +++++++++++---- tests/test_savitzky_golay_smooth.py | 6 ++++-- tests/test_scale_intensity.py | 31 +++++++++++++++++++---------- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 4817214ce9..dde26b9c85 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -334,7 +334,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return shifter(img) -class ScaleIntensity(Transform): +class ScaleIntensity(TorchOrNumpyTransform): """ Scale the intensity of input image to the given value range (minv, maxv). If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``. @@ -354,7 +354,7 @@ def __init__( self.maxv = maxv self.factor = factor - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
@@ -363,9 +363,13 @@ def __call__(self, img: np.ndarray) -> np.ndarray: """ if self.minv is not None and self.maxv is not None: - return np.asarray(rescale_array(img, self.minv, self.maxv, img.dtype)) + return rescale_array(img, self.minv, self.maxv, img.dtype) if self.factor is not None: - return np.asarray(img * (1 + self.factor), dtype=img.dtype) + out = img * (1 + self.factor) + if isinstance(out, torch.Tensor): + return out.to(img.dtype) # type: ignore + else: + return out.astype(img.dtype) # type: ignore raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e45bd25bcd..52357e15cd 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,6 +10,7 @@ # limitations under the License. import itertools +from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy import random import warnings from contextlib import contextmanager @@ -117,15 +118,22 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) -def rescale_array(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32): +def rescale_array(arr: DataObjects.Images, minv: float = 0.0, maxv: float = 1.0, dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32) -> DataObjects.Images: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. """ if dtype is not None: - arr = arr.astype(dtype) + if isinstance(arr, torch.Tensor): + if isinstance(dtype, np.dtype): + dtype = dtype_numpy_to_torch(dtype) + arr = arr.to(dtype) # type: ignore + else: + if isinstance(dtype, torch.dtype): + dtype = dtype_torch_to_numpy(dtype) + arr = arr.astype(dtype) # type: ignore - mina = np.min(arr) - maxa = np.max(arr) + mina = arr.min() + maxa = arr.max() if mina == maxa: return arr * minv diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 10553179d3..2387bd83dc 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List import unittest +from typing import Callable, List import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import SavitzkyGolaySmooth @@ -66,10 +66,12 @@ def test_value(self, arguments, image, expected_data, atol): from functools import partial + NDARRAYS: List[Callable] = [np.array, torch.Tensor] if torch.cuda.is_available(): NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestSavitzkyGolaySmoothREP(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) def test_value(self, arguments, image, expected_data, atol): diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 61e89191fd..f8247b9221 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -9,29 +9,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
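The rescale_array change above is where numpy and torch dtypes meet: the requested dtype may arrive as either flavour, so it is translated before casting. A trimmed, illustrative version of that bridging (the real lookup tables live in monai.utils.misc as dtype_numpy_to_torch and dtype_torch_to_numpy):

import numpy as np
import torch

_np_to_torch = {
    np.dtype(np.float32): torch.float32,
    np.dtype(np.float64): torch.float64,
}
_torch_to_np = {v: k for k, v in _np_to_torch.items()}

def cast(arr, dtype):
    # translate the dtype to match the array's backend, then cast
    if isinstance(arr, torch.Tensor):
        if isinstance(dtype, np.dtype):
            dtype = _np_to_torch[dtype]
        return arr.to(dtype)
    if isinstance(dtype, torch.dtype):
        dtype = _torch_to_np[dtype]
    return arr.astype(dtype)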
+from typing import Callable, List import unittest import numpy as np +import torch from monai.transforms import ScaleIntensity from tests.utils import NumpyImageTestCase2D +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) class TestScaleIntensity(NumpyImageTestCase2D): def test_range_scale(self): - scaler = ScaleIntensity(minv=1.0, maxv=2.0) - result = scaler(self.imt) - mina = np.min(self.imt) - maxa = np.max(self.imt) - norm = (self.imt - mina) / (maxa - mina) - expected = (norm * (2.0 - 1.0)) + 1.0 - np.testing.assert_allclose(result, expected) + for p in NDARRAYS: + scaler = ScaleIntensity(minv=1.0, maxv=2.0) + result = scaler(p(self.imt)) + mina = self.imt.min() + maxa = self.imt.max() + norm = (self.imt - mina) / (maxa - mina) + expected = (norm * (2.0 - 1.0)) + 1.0 + np.testing.assert_allclose(result, expected) def test_factor_scale(self): - scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) - result = scaler(self.imt) - expected = (self.imt * (1 + 0.1)).astype(np.float32) - np.testing.assert_allclose(result, expected) + for p in NDARRAYS: + scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) + result = scaler(p(self.imt)) + expected = (self.imt * (1 + 0.1)).astype(np.float32) + np.testing.assert_allclose(result, expected) if __name__ == "__main__": From db2ab8e9e5ea163189409b78e2259676d37bf488 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 13:28:31 +0100 Subject: [PATCH 030/176] rand scale intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 4 ++-- tests/test_rand_scale_intensity.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index dde26b9c85..5346c3b162 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -373,7 +373,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") -class RandScaleIntensity(RandomizableTransform): +class RandScaleIntensity(TorchOrNumpyTransform, RandomizableTransform): """ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor` is randomly picked. @@ -400,7 +400,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 2126301758..cd48d1b6b2 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -9,22 +9,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
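The tests that follow recompute the random factor with the global NumPy generator; this works because set_random_state(seed=0) hands the transform a np.random.RandomState(0), and both generators seeded with the same value yield identical streams:

import numpy as np

rs = np.random.RandomState(0)
a = rs.uniform(low=-0.5, high=0.5)

np.random.seed(0)
b = np.random.uniform(low=-0.5, high=0.5)

assert a == b  # same stream, so the transform's factor is reproducible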
+from typing import Callable, List import unittest import numpy as np +import torch from monai.transforms import RandScaleIntensity from tests.utils import NumpyImageTestCase2D +from functools import partial + +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) class TestRandScaleIntensity(NumpyImageTestCase2D): def test_value(self): - scaler = RandScaleIntensity(factors=0.5, prob=1.0) - scaler.set_random_state(seed=0) - result = scaler(self.imt) - np.random.seed(0) - expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - np.testing.assert_allclose(result, expected) + for p in NDARRAYS: + scaler = RandScaleIntensity(factors=0.5, prob=1.0) + scaler.set_random_state(seed=0) + result = scaler(p(self.imt)) + np.random.seed(0) + expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) + np.testing.assert_allclose(result, expected) if __name__ == "__main__": From 0239a066ad9c17bf5f200fe2ddacb0555b3f47f3 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 13:32:37 +0100 Subject: [PATCH 031/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/synthetic.py | 4 ++-- monai/transforms/utils.py | 9 +++++++-- tests/test_rand_scale_intensity.py | 4 ++-- tests/test_scale_intensity.py | 4 ++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py index 90cbe13c2d..32627a12e8 100644 --- a/monai/data/synthetic.py +++ b/monai/data/synthetic.py @@ -65,7 +65,7 @@ def create_test_image_2d( labels = np.ceil(image).astype(np.int32) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) - noisyimage = rescale_array(np.maximum(image, norm)) + noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore if channel_dim is not None: if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)): @@ -129,7 +129,7 @@ def create_test_image_3d( labels = np.ceil(image).astype(np.int32) norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape) - noisyimage = rescale_array(np.maximum(image, norm)) + noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore if channel_dim is not None: if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 52357e15cd..fda68b1f36 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -10,7 +10,6 @@ # limitations under the License. 
import itertools -from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy import random import warnings from contextlib import contextmanager @@ -36,6 +35,7 @@ optional_import, ) from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -118,7 +118,12 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :]) -def rescale_array(arr: DataObjects.Images, minv: float = 0.0, maxv: float = 1.0, dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32) -> DataObjects.Images: +def rescale_array( + arr: DataObjects.Images, + minv: float = 0.0, + maxv: float = 1.0, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, +) -> DataObjects.Images: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. """ diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index cd48d1b6b2..22965e6e20 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List import unittest +from functools import partial +from typing import Callable, List import numpy as np import torch from monai.transforms import RandScaleIntensity from tests.utils import NumpyImageTestCase2D -from functools import partial NDARRAYS: List[Callable] = [np.array, torch.Tensor] if torch.cuda.is_available(): diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index f8247b9221..c7d67c7077 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -9,15 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, List import unittest +from functools import partial +from typing import Callable, List import numpy as np import torch from monai.transforms import ScaleIntensity from tests.utils import NumpyImageTestCase2D -from functools import partial NDARRAYS: List[Callable] = [np.array, torch.Tensor] if torch.cuda.is_available(): From c3880df5f73c05f3ae1e23552299de6ff41bb461 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 15:02:05 +0100 Subject: [PATCH 032/176] test smart cache dataset returns torch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_smartcachedataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 1499854c56..5e14e7c5d2 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -17,6 +17,7 @@ import nibabel as nib import numpy as np from parameterized import parameterized +import torch from monai.data import SmartCacheDataset from monai.transforms import Compose, LoadImaged @@ -65,7 +66,7 @@ def test_shape(self, replace_rate, num_replace_workers, transform): for _ in range(3): dataset.update_cache() self.assertIsNotNone(dataset[15]) - if isinstance(dataset[15]["image"], np.ndarray): + if isinstance(dataset[15]["image"], (np.ndarray, torch.Tensor)): np.testing.assert_allclose(dataset[15]["image"], dataset[15]["label"]) else: self.assertIsInstance(dataset[15]["image"], str) From 6a2688903254778cc4c41852c1bc687cf11ef502 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 3 Jun 2021 15:08:30 +0100 Subject: [PATCH 033/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_smartcachedataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 5e14e7c5d2..d1c717df21 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -16,8 +16,8 @@ import nibabel as nib import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.data import SmartCacheDataset from monai.transforms import Compose, LoadImaged From 5513e5d4807429b824b3573730f392e324d8e62b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 11:54:37 +0100 Subject: [PATCH 034/176] post merge changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/config/type_definitions.py | 1 + monai/transforms/croppad/dictionary.py | 6 +++--- monai/transforms/utility/array.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index b3b86231de..375ae460b2 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -12,6 +12,7 @@ from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union import numpy as np +import torch __all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "TensorOrList"] diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bdd38d5dfe..f44d12e95f 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -20,7 +20,7 @@ from enum import Enum from itertools import chain from math import ceil, 
floor -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -857,7 +857,7 @@ def __call__(self, data: DataObjects.Mapping) -> List[DataObjects.Dict]: _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + results: List[Dict[Hashable, Any]] = [dict(data) for _ in range(self.num_samples)] # fill in the extra keys with unmodified data for i in range(self.num_samples): for key in set(data.keys()).difference(set(self.keys)): @@ -1029,7 +1029,7 @@ def get_as_np(x) -> np.ndarray: raise AssertionError # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + results: List[Dict[Hashable, Any]] = [dict(data) for _ in range(self.num_samples)] for i, center in enumerate(self.centers): # fill in the extra keys with unmodified data diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dd386fd31e..44795ee5a9 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -17,7 +17,7 @@ import sys import time import warnings -from typing import Callable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch From 03b825bae672bb1b55b75a4d463e1361996f8044 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 12:29:12 +0100 Subject: [PATCH 035/176] StdShiftIntensity and RandStdShiftIntensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 38 ++++++++++------ monai/transforms/transform.py | 28 ++++++++---- monai/utils/misc.py | 1 + tests/test_rand_std_shift_intensity.py | 23 +++++++--- tests/test_std_shift_intensity.py | 63 ++++++++++++++++---------- 5 files changed, 100 insertions(+), 53 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 5346c3b162..b826c22abc 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -14,7 +14,9 @@ """ from collections.abc import Iterable -from typing import Any, List, Optional, Sequence, Tuple, Union +from copy import deepcopy +from functools import partial +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -237,7 +239,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return shifter(img) -class StdShiftIntensity(Transform): +class StdShiftIntensity(TorchOrNumpyTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)``. 
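A worked example of the nonzero flag handled in the hunk below (mirroring test_std_shift_intensity): only nonzero voxels enter the std computation, and only they are shifted.

import numpy as np

from monai.transforms import StdShiftIntensity

img = np.array([[4.0, 0.0, 2.0], [0.0, 2.0, 4.0]])  # std of nonzero voxels = 1
out = StdShiftIntensity(factor=2.0, nonzero=True)(img)
# zeros stay put, everything else moves by factor * std = 2
# out -> [[6., 0., 4.], [0., 4., 6.]]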
@@ -260,28 +262,38 @@ def __init__( self.channel_wise = channel_wise self.dtype = dtype - def _stdshift(self, img: np.ndarray) -> np.ndarray: - slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool) - if not np.any(slices): - return img - offset = self.factor * np.std(img[slices]) - img[slices] = img[slices] + offset + def _stdshift(self, img: DataObjects.Images) -> DataObjects.Images: + ones: Callable + std: Callable + if isinstance(img, torch.Tensor): + ones = torch.ones + std = partial(torch.std, unbiased=False) + else: + ones = np.ones + std = np.std + + slices = (img != 0) if self.nonzero else ones(img.shape, dtype=bool) + if slices.any(): + out = deepcopy(img) + offset = self.factor * std(out[slices]) + out[slices] = out[slices] + offset + return out return img - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img = img.astype(self.dtype) + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: for i, d in enumerate(img): - img[i] = self._stdshift(d) + img[i] = self._stdshift(d) # type: ignore else: img = self._stdshift(img) return img -class RandStdShiftIntensity(RandomizableTransform): +class RandStdShiftIntensity(TorchOrNumpyTransform, RandomizableTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. @@ -321,7 +333,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) super().randomize(None) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
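One subtlety in the _stdshift rewrite above: torch.std applies Bessel's correction (divides by n - 1) by default, whereas np.std divides by n, hence the partial(torch.std, unbiased=False). Concretely:

import numpy as np
import torch

x = np.array([1.0, 2.0, 3.0, 4.0])
t = torch.as_tensor(x)

np.std(x)                            # 1.1180..., divides by n
torch.std(t).item()                  # 1.2909..., divides by n - 1
torch.std(t, unbiased=False).item()  # 1.1180..., matches NumPy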
""" diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1912dcfc12..e6981d9f4a 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -21,6 +21,7 @@ from monai import transforms from monai.config import KeysCollection +from monai.config.type_definitions import DtypeLike from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import DataObjects from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy @@ -38,19 +39,30 @@ def convert_data_type( - data, output_type: type, device: Optional[torch.device] = None + data: DataObjects.Images, + output_type: Optional[type] = None, + device: Optional[torch.device] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, ) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: """Convert to torch.Tensor/np.ndarray.""" orig_type = type(data) - assert orig_type in (torch.Tensor, np.ndarray), f"Expected types {(torch.Tensor, np.ndarray)}, got {orig_type}" orig_device = data.device if isinstance(data, torch.Tensor) else None - if orig_type is np.ndarray and output_type is torch.Tensor: - dtype = dtype_numpy_to_torch(data.dtype) - data = torch.Tensor(np.ascontiguousarray(data)).to(dtype) - elif orig_type is torch.Tensor and output_type is np.ndarray: - dtype = dtype_torch_to_numpy(data.dtype) - data = data.detach().cpu().numpy().astype(dtype) + output_type = output_type or orig_type + dtype = dtype or data.dtype + + if output_type is torch.Tensor: + if orig_type is np.ndarray: + data = torch.Tensor(np.ascontiguousarray(data)) + if dtype != data.dtype: + dtype = dtype_numpy_to_torch(dtype) if not isinstance(dtype, torch.dtype) else dtype + data = data.to(dtype) # type: ignore + elif output_type is np.ndarray: + if orig_type is torch.Tensor: + data = data.detach().cpu().numpy() # type: ignore + if dtype != data.dtype: + dtype = dtype_torch_to_numpy(dtype) if isinstance(dtype, torch.dtype) else dtype + data = data.astype(dtype) # type: ignore if isinstance(data, torch.Tensor) and device is not None: data.to(device) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4b75036183..084108132f 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -316,6 +316,7 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" + dtype = np.dtype(dtype) if type(dtype) == type else dtype return _np_to_torch_dtype[dtype] diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index 9aff50ab66..c5004c3060 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -10,22 +10,31 @@ # limitations under the License. 
import unittest +from functools import partial +from typing import Callable, List import numpy as np +import torch from monai.transforms import RandStdShiftIntensity from tests.utils import NumpyImageTestCase2D +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestRandStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): - shifter = RandStdShiftIntensity(factors=1.0, prob=1.0) - shifter.set_random_state(seed=0) - result = shifter(self.imt) - np.random.seed(0) - factor = np.random.uniform(low=-1.0, high=1.0) - expected = self.imt + factor * np.std(self.imt) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in NDARRAYS: + shifter = RandStdShiftIntensity(factors=1.0, prob=1.0) + shifter.set_random_state(seed=0) + result = shifter(p(self.imt)) + np.random.seed(0) + factor = np.random.uniform(low=-1.0, high=1.0) + offset = factor * np.std(self.imt) + expected = self.imt + offset + np.testing.assert_allclose(result, expected, rtol=1e-5) if __name__ == "__main__": diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index f317330435..81b50cfb63 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -10,47 +10,60 @@ # limitations under the License. import unittest +from functools import partial +from typing import Callable, List import numpy as np +import torch from monai.transforms import ShiftIntensity, StdShiftIntensity from tests.utils import NumpyImageTestCase2D +NDARRAYS: List[Callable] = [np.array, torch.Tensor] +if torch.cuda.is_available(): + NDARRAYS.append(partial(torch.Tensor, device="cuda")) + class TestStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): - factor = np.random.rand() - offset = np.std(self.imt) * factor - shifter = ShiftIntensity(offset=offset) - expected = shifter(self.imt) - std_shifter = StdShiftIntensity(factor=factor) - result = std_shifter(self.imt) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in NDARRAYS: + imt = p(self.imt) + factor = np.random.rand() + offset = np.std(self.imt) * factor + shifter = ShiftIntensity(offset=offset) + expected = shifter(imt) + std_shifter = StdShiftIntensity(factor=factor) + result = std_shifter(imt) + np.testing.assert_allclose(result, expected, rtol=1e-5) def test_zerostd(self): image = np.ones([2, 3, 3]) - for nonzero in [True, False]: - for channel_wise in [True, False]: - factor = np.random.rand() - std_shifter = StdShiftIntensity(factor=factor, nonzero=nonzero, channel_wise=channel_wise) - result = std_shifter(image) - np.testing.assert_allclose(result, image, rtol=1e-5) + for p in NDARRAYS: + image = p(image) + for nonzero in [True, False]: + for channel_wise in [True, False]: + factor = np.random.rand() + std_shifter = StdShiftIntensity(factor=factor, nonzero=nonzero, channel_wise=channel_wise) + result = std_shifter(image) + np.testing.assert_allclose(result, image, rtol=1e-5) def test_nonzero(self): - image = np.asarray([[4.0, 0.0, 2.0], [0, 2, 4]]) # std = 1 - factor = np.random.rand() - std_shifter = StdShiftIntensity(factor=factor, nonzero=True) - result = std_shifter(image) - expected = np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]]) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in NDARRAYS: + image = p(np.asarray([[4.0, 0.0, 2.0], [0, 2, 4]])) # std = 1 + factor = np.random.rand() + std_shifter = StdShiftIntensity(factor=factor, nonzero=True) + result = 
std_shifter(image) + expected = np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]]) + np.testing.assert_allclose(result, expected, rtol=1e-5) def test_channel_wise(self): - image = np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0]))) # std: 0.5, 0 - factor = np.random.rand() - std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) - result = std_shifter(image) - expected = np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))) - np.testing.assert_allclose(result, expected, rtol=1e-5) + for p in NDARRAYS: + image = p(np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0])))) # std: 0.5, 0 + factor = np.random.rand() + std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) + result = std_shifter(image) + expected = np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))) + np.testing.assert_allclose(result, expected, rtol=1e-5) def test_dtype(self): trans_dtype = np.float32 From 9739ba50cf00aca2ec32028fba639dab4cee1824 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 12:46:42 +0100 Subject: [PATCH 036/176] TEST_NDARRAYS reduce code duplication Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_crop_foreground.py | 9 ++------- tests/test_detect_envelope.py | 18 ++++++++---------- tests/test_gaussian_sharpen.py | 11 ++--------- tests/test_gaussian_smooth.py | 10 ++-------- tests/test_mask_intensity.py | 13 +++---------- tests/test_rand_gaussian_sharpen.py | 11 ++--------- tests/test_rand_rotate90.py | 17 +++++------------ tests/test_rand_scale_intensity.py | 11 ++--------- tests/test_rand_std_shift_intensity.py | 11 ++--------- tests/test_rotate90.py | 17 +++++------------ tests/test_savitzky_golay_smooth.py | 12 ++---------- tests/test_scale_intensity.py | 13 +++---------- tests/test_std_shift_intensity.py | 17 +++++------------ tests/utils.py | 9 ++++++++- 14 files changed, 51 insertions(+), 128 deletions(-) diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 93b9de8fb4..4c51a274a0 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -10,21 +10,16 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import CropForeground +from tests.utils import TEST_NDARRAYS TESTS = [] -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) -for p in NDARRAYS: +for p in TEST_NDARRAYS: TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index d0ba57ede1..654183f01a 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -10,15 +10,19 @@ # limitations under the License. 
import unittest -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import DetectEnvelope from monai.utils import InvalidPyTorchVersionError, OptionalImportError -from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, SkipIfModule, SkipIfNoModule +from tests.utils import ( + TEST_NDARRAYS, + SkipIfAtLeastPyTorchVersion, + SkipIfBeforePyTorchVersion, + SkipIfModule, + SkipIfNoModule, +) n_samples = 500 hann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples) @@ -112,12 +116,6 @@ TEST_CASE_INVALID_OBJ = [{}, "a string", "__call__"] # method expected to raise exception -from functools import partial - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - @SkipIfBeforePyTorchVersion((1, 7)) @SkipIfNoModule("torch.fft") @@ -133,7 +131,7 @@ class TestDetectEnvelope(unittest.TestCase): ] ) def test_value(self, arguments, image, expected_data, atol): - for p in NDARRAYS: + for p in TEST_NDARRAYS: result = DetectEnvelope(**arguments)(p(image)) np.testing.assert_allclose(result, expected_data, atol=atol) diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index 0221d6e241..cdd4147c41 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -10,23 +10,16 @@ # limitations under the License. import unittest -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import GaussianSharpen +from tests.utils import TEST_NDARRAYS TESTS = [] -from functools import partial - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - -for p in NDARRAYS: +for p in TEST_NDARRAYS: TESTS.append( [ {}, diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index 6305eae0ee..f7f69a5d6a 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -10,22 +10,16 @@ # limitations under the License. import unittest -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import GaussianSmooth +from tests.utils import TEST_NDARRAYS TESTS = [] -from functools import partial -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - -for p in NDARRAYS: +for p in TEST_NDARRAYS: TESTS.append( [ {"sigma": 1.5}, diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index fa286290a9..5a33ef9171 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -10,24 +10,17 @@ # limitations under the License. 
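Note: for transforms that take an auxiliary array, the container types are crossed independently, as the nested loop at the top of the MaskIntensity diff just below shows. A sketch, with img, mask and expected as stand-ins for the actual case data:

    for p in TEST_NDARRAYS:  # container for the image
        for q in TEST_NDARRAYS:  # container for the mask, varied independently of the image
            TEST_CASES.append([{"mask_data": q(mask)}, p(img), expected])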
import unittest -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import MaskIntensity +from tests.utils import TEST_NDARRAYS TEST_CASES = [] -from functools import partial - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - -for p in NDARRAYS: - for q in NDARRAYS: +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: TEST_CASES.append( [ diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 6bf5b4bb58..beb21f2acc 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -10,23 +10,16 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandGaussianSharpen +from tests.utils import TEST_NDARRAYS TESTS = [] - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - -for p in NDARRAYS: +for p in TEST_NDARRAYS: TESTS.append( [ {"prob": 1.0}, diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 313ff64353..55036d38cb 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -10,24 +10,17 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import RandRotate90 -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandRotate90(NumpyImageTestCase2D): def test_default(self): rotate = RandRotate90() - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotate.set_random_state(123) rotated = rotate(p(self.imt[0])) expected = [] @@ -38,7 +31,7 @@ def test_default(self): def test_k(self): rotate = RandRotate90(max_k=2) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] @@ -49,7 +42,7 @@ def test_k(self): def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] @@ -60,7 +53,7 @@ def test_spatial_axes(self): def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotate.set_random_state(234) rotated = rotate(p(self.imt[0])) expected = [] diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 22965e6e20..22fa0f4ce8 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -10,23 +10,16 @@ # limitations under the License. 
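Note: the randomizable transforms in these files reseed before every container type, so each input sees an identical sequence of random draws and the per-container results stay comparable. The pattern, as in the RandRotate90 tests above:

    rotate = RandRotate90(max_k=2)
    for p in TEST_NDARRAYS:
        rotate.set_random_state(234)  # reset the RNG so every container gets the same draws
        rotated = rotate(p(self.imt[0]))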
import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import RandScaleIntensity -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandScaleIntensity(NumpyImageTestCase2D): def test_value(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: scaler = RandScaleIntensity(factors=0.5, prob=1.0) scaler.set_random_state(seed=0) result = scaler(p(self.imt)) diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index c5004c3060..ae22596b67 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -10,23 +10,16 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import RandStdShiftIntensity -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: shifter = RandStdShiftIntensity(factors=1.0, prob=1.0) shifter.set_random_state(seed=0) result = shifter(p(self.imt)) diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index ee77d9eae9..83959ba3a5 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -10,24 +10,17 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import Rotate90 -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRotate90(NumpyImageTestCase2D): def test_rotate90_default(self): rotate = Rotate90() - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -37,7 +30,7 @@ def test_rotate90_default(self): def test_k(self): rotate = Rotate90(k=2) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -47,7 +40,7 @@ def test_k(self): def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: @@ -57,7 +50,7 @@ def test_spatial_axes(self): def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) - for p in NDARRAYS: + for p in TEST_NDARRAYS: rotated = rotate(p(self.imt[0])) expected = [] for channel in self.imt[0]: diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 2387bd83dc..bd3be49f11 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -10,13 +10,12 @@ # limitations under the License. 
import unittest -from typing import Callable, List import numpy as np -import torch from parameterized import parameterized from monai.transforms import SavitzkyGolaySmooth +from tests.utils import TEST_NDARRAYS # Zero-padding trivial tests @@ -65,16 +64,9 @@ def test_value(self, arguments, image, expected_data, atol): np.testing.assert_allclose(result, expected_data, atol=atol) -from functools import partial - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) - - class TestSavitzkyGolaySmoothREP(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE_REP]) def test_value(self, arguments, image, expected_data, atol): - for p in NDARRAYS: + for p in TEST_NDARRAYS: result = SavitzkyGolaySmooth(**arguments)(p(image)) np.testing.assert_allclose(result, expected_data, atol=atol) diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index c7d67c7077..89e065cb6b 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -10,23 +10,16 @@ # limitations under the License. import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import ScaleIntensity -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestScaleIntensity(NumpyImageTestCase2D): def test_range_scale(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=1.0, maxv=2.0) result = scaler(p(self.imt)) mina = self.imt.min() @@ -36,7 +29,7 @@ def test_range_scale(self): np.testing.assert_allclose(result, expected) def test_factor_scale(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) result = scaler(p(self.imt)) expected = (self.imt * (1 + 0.1)).astype(np.float32) diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index 81b50cfb63..1d385a54a4 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -10,23 +10,16 @@ # limitations under the License. 
import unittest -from functools import partial -from typing import Callable, List import numpy as np -import torch from monai.transforms import ShiftIntensity, StdShiftIntensity -from tests.utils import NumpyImageTestCase2D - -NDARRAYS: List[Callable] = [np.array, torch.Tensor] -if torch.cuda.is_available(): - NDARRAYS.append(partial(torch.Tensor, device="cuda")) +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestStdShiftIntensity(NumpyImageTestCase2D): def test_value(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: imt = p(self.imt) factor = np.random.rand() offset = np.std(self.imt) * factor @@ -38,7 +31,7 @@ def test_value(self): def test_zerostd(self): image = np.ones([2, 3, 3]) - for p in NDARRAYS: + for p in TEST_NDARRAYS: image = p(image) for nonzero in [True, False]: for channel_wise in [True, False]: @@ -48,7 +41,7 @@ def test_zerostd(self): np.testing.assert_allclose(result, image, rtol=1e-5) def test_nonzero(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: image = p(np.asarray([[4.0, 0.0, 2.0], [0, 2, 4]])) # std = 1 factor = np.random.rand() std_shifter = StdShiftIntensity(factor=factor, nonzero=True) @@ -57,7 +50,7 @@ def test_nonzero(self): np.testing.assert_allclose(result, expected, rtol=1e-5) def test_channel_wise(self): - for p in NDARRAYS: + for p in TEST_NDARRAYS: image = p(np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0])))) # std: 0.5, 0 factor = np.random.rand() std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) diff --git a/tests/utils.py b/tests/utils.py index 2edc51d3e8..4b1d423b08 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,9 +20,10 @@ import traceback import unittest import warnings +from functools import partial from io import BytesIO from subprocess import PIPE, Popen -from typing import Optional +from typing import Callable, Optional, Tuple from urllib.error import ContentTooShortError, HTTPError, URLError import numpy as np @@ -583,5 +584,11 @@ def query_memory(n=2): return ",".join([f"{int(x)}" for x in ids]) +TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.Tensor) # type: ignore +if torch.cuda.is_available(): + gpu_tensor: Callable = partial(torch.Tensor, device="cuda") + TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore + + if __name__ == "__main__": print(query_memory()) From a270fba3db5d015f5b71c62ac1aad345c89bfdf6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 16:25:17 +0100 Subject: [PATCH 037/176] dtype_convert and testing Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 11 ++---- monai/transforms/intensity/dictionary.py | 5 +-- monai/transforms/transform.py | 6 ++-- monai/transforms/utils.py | 12 ++----- monai/utils/__init__.py | 1 + monai/utils/misc.py | 21 +++++++++++ tests/test_convert_data_type.py | 45 ++++++++++++++++++++++++ tests/test_dtype_convert.py | 36 +++++++++++++++++++ 8 files changed, 113 insertions(+), 24 deletions(-) create mode 100644 tests/test_convert_data_type.py create mode 100644 tests/test_dtype_convert.py diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index b826c22abc..47bd331755 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -32,14 +32,9 @@ convert_data_type, ) from monai.transforms.utils import rescale_array -from monai.utils import ( - PT_BEFORE_1_7, - InvalidPyTorchVersionError, - dtype_torch_to_numpy, - ensure_tuple_rep, - 
ensure_tuple_size, -) +from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_convert __all__ = [ "RandGaussianNoise", @@ -99,7 +94,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: raise AssertionError if not self._do_transform: return img - dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype + dtype = dtype_convert(img.dtype, np.array) return img + self._noise.astype(dtype) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 1deedfc394..26e5d3a86c 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -39,8 +39,9 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform -from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size +from monai.utils import ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_convert __all__ = [ "RandGaussianNoised", @@ -159,7 +160,7 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: if not self._do_transform: return d for key, noise in self.key_iterator(d, self._noise): - dtype = dtype_torch_to_numpy(d[key].dtype) if isinstance(d[key], torch.Tensor) else d[key].dtype + dtype = dtype_convert(d[key].dtype, np.array) d[key] = d[key] + noise.astype(dtype) return d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e6981d9f4a..ffe98c1b5b 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -24,7 +24,7 @@ from monai.config.type_definitions import DtypeLike from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy +from monai.utils.misc import dtype_convert __all__ = [ "ThreadUnsafe", @@ -49,19 +49,17 @@ def convert_data_type( orig_device = data.device if isinstance(data, torch.Tensor) else None output_type = output_type or orig_type - dtype = dtype or data.dtype + dtype = dtype_convert(dtype or data.dtype, output_type) if output_type is torch.Tensor: if orig_type is np.ndarray: data = torch.Tensor(np.ascontiguousarray(data)) if dtype != data.dtype: - dtype = dtype_numpy_to_torch(dtype) if not isinstance(dtype, torch.dtype) else dtype data = data.to(dtype) # type: ignore elif output_type is np.ndarray: if orig_type is torch.Tensor: data = data.detach().cpu().numpy() # type: ignore if dtype != data.dtype: - dtype = dtype_torch_to_numpy(dtype) if isinstance(dtype, torch.dtype) else dtype data = data.astype(dtype) # type: ignore if isinstance(data, torch.Tensor) and device is not None: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index fda68b1f36..a1d47e0722 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -21,7 +21,7 @@ from monai.config import DtypeLike, IndexSelection from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, convert_data_type from monai.utils import ( GridSampleMode, InterpolateMode, @@ -35,7 +35,6 @@ optional_import, ) from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_numpy_to_torch, dtype_torch_to_numpy measure, _ = optional_import("skimage.measure", 
"0.14.2", min_version) @@ -128,14 +127,7 @@ def rescale_array( Rescale the values of numpy array `arr` to be from `minv` to `maxv`. """ if dtype is not None: - if isinstance(arr, torch.Tensor): - if isinstance(dtype, np.dtype): - dtype = dtype_numpy_to_torch(dtype) - arr = arr.to(dtype) # type: ignore - else: - if isinstance(dtype, torch.dtype): - dtype = dtype_torch_to_numpy(dtype) - arr = arr.astype(dtype) # type: ignore + arr, *_ = convert_data_type(arr, dtype=dtype) mina = arr.min() maxa = arr.max() diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 21b3aa1521..1418ef9848 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -40,6 +40,7 @@ MAX_SEED, ImageMetaKey, copy_to_device, + dtype_convert, dtype_numpy_to_torch, dtype_torch_to_numpy, ensure_tuple, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 084108132f..8873162d84 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -40,6 +40,7 @@ "list_to_dict", "dtype_torch_to_numpy", "dtype_numpy_to_torch", + "dtype_convert", "MAX_SEED", "copy_to_device", "get_dist_device", @@ -316,10 +317,30 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" + # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them dtype = np.dtype(dtype) if type(dtype) == type else dtype return _np_to_torch_dtype[dtype] +def dtype_convert(dtype, data_type): + """Convert to the `dtype` that corresponds to `data_type`. + + Example: + im = torch.Tensor((1)) + dtype = dtype_convert(np.float32, type(im)) + """ + if data_type is torch.Tensor: + if type(dtype) is torch.dtype: + return dtype + return dtype_numpy_to_torch(dtype) + else: + # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them + dtype = np.dtype(dtype) if type(dtype) == type else dtype + if type(dtype) is np.dtype: + return dtype + return dtype_torch_to_numpy(dtype) + + def copy_to_device( obj: Any, device: Optional[Union[str, torch.device]], diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py new file mode 100644 index 0000000000..3a47e7d436 --- /dev/null +++ b/tests/test_convert_data_type.py @@ -0,0 +1,45 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from typing import List, Tuple + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.transform import convert_data_type +from monai.utils.misc import dtype_convert + +TESTS: List[Tuple] = [] +TESTS.append((np.array, torch.Tensor, np.float32, torch.float32)) +TESTS.append((torch.Tensor, np.ndarray, np.float32, torch.float32)) +TESTS.append((np.array, torch.Tensor, torch.float32, np.float32)) +TESTS.append((torch.Tensor, np.ndarray, torch.float32, np.float32)) + + +class TestConvertDataType(unittest.TestCase): + @staticmethod + def get_im(im_type, dtype): + dtype = dtype_convert(dtype, im_type) + lib = torch if im_type is torch.Tensor else np + return lib.zeros((1, 2, 3), dtype=dtype) + + @parameterized.expand(TESTS) + def test_convert_data_type(self, in_type, out_type, in_dtype, out_dtype): + orig_im = self.get_im(in_type, in_dtype) + converted_im, orig_type, _ = convert_data_type(orig_im, out_type) + self.assertEqual(type(orig_im), orig_type) + self.assertEqual(type(converted_im), out_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dtype_convert.py b/tests/test_dtype_convert.py new file mode 100644 index 0000000000..b07f58989e --- /dev/null +++ b/tests/test_dtype_convert.py @@ -0,0 +1,36 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
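Note: the round trip exercised by the test above, written out; orig_type and orig_device are what let a transform hand back the caller's original container afterwards:

    import numpy as np
    import torch
    from monai.transforms.transform import convert_data_type

    img = np.zeros((1, 2, 3), dtype=np.float32)
    img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor)
    assert isinstance(img_t, torch.Tensor) and orig_type is np.ndarray
    img_back, *_ = convert_data_type(img_t, orig_type, orig_device)  # revert to the input type
    assert isinstance(img_back, np.ndarray)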
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.utils.misc import dtype_convert
+
+TEST_NDARRAYS = [torch.Tensor, np.ndarray]
+DTYPES = [torch.float32, np.float32, np.dtype(np.float32)]
+
+TESTS = []
+for im_type in TEST_NDARRAYS:
+    for im_dtype in DTYPES:
+        TESTS.append((im_type, im_dtype))
+
+
+class TestDtypeConvert(unittest.TestCase):
+    @parameterized.expand(TESTS)
+    def test_dtype_convert(self, im_type, desired_dtype):
+        self.assertIsNotNone(dtype_convert(desired_dtype, im_type))
+
+
+if __name__ == "__main__":
+    unittest.main()

From c95f0595a1d95d72b6551dbe90211f914a298b40 Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Mon, 14 Jun 2021 17:56:35 +0100
Subject: [PATCH 038/176] spatial pad, border pad, divisible pad

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/transforms/__init__.py      |  1 +
 monai/transforms/croppad/array.py | 94 ++++++++++++++++++++++++++-----
 tests/test_border_pad.py          | 12 ++--
 tests/test_spatial_pad.py         | 79 +++++++++++++++++++-------
 4 files changed, 145 insertions(+), 41 deletions(-)

diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 2055a85edd..ef07bbe541 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -18,6 +18,7 @@
     CenterSpatialCrop,
     CropForeground,
     DivisiblePad,
+    Pad,
     RandCropByPosNegLabel,
     RandScaleCrop,
     RandSpatialCrop,
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index 058da71424..b87081a95c 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -18,6 +18,8 @@
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
+import torch
+from torch.nn.functional import pad as pad_pt
 
 from monai.config import IndexSelection
 from monai.data.utils import get_random_patch, get_valid_patch_size
@@ -34,6 +36,7 @@
 from monai.utils.enums import DataObjects
 
 __all__ = [
+    "Pad",
     "SpatialPad",
     "BorderPad",
     "DivisiblePad",
@@ -51,9 +54,71 @@
 ]
 
 
-class SpatialPad(Transform):
+class Pad(TorchOrNumpyTransform):
+    """
+    Perform padding for a given amount of padding in each dimension.
+
+    If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
+    Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).
+    In practice a mode therefore needs to be provided; see numpy.lib.arraypad.pad
+    for additional details.
+
+    Args:
+        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
+        mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
+            ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
+            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
+ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ + + def __init__( + self, + to_pad: List[Tuple[int, int]], + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + ) -> None: + self.to_pad = to_pad + self.mode = mode + + @staticmethod + def _np_pad(img: DataObjects.Images, all_pad_width, mode) -> DataObjects.Images: + out, orig_type, orig_device = convert_data_type(img, np.ndarray) + out = np.pad(out, all_pad_width, mode=mode) + out, *_ = convert_data_type(out, orig_type, orig_device) + return out + + @staticmethod + def _pt_pad(img: DataObjects.Images, all_pad_width, mode) -> DataObjects.Images: + out, orig_type, orig_device = convert_data_type(img, torch.Tensor) + pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] + out = pad_pt(out, pt_pad_width, mode=mode) # type: ignore + out, *_ = convert_data_type(out, orig_type, orig_device) + return out + + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: + """ + Args: + img: data to be transformed, assuming `img` is channel-first and + padding doesn't apply to the channel dim. + mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, + ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} + One of the listed string values or a user supplied function. Defaults to ``self.mode``. + See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ + if not np.asarray(self.to_pad).any(): + # all zeros, skip padding + return img + mode = mode or self.mode + mode = mode.value if isinstance(mode, NumpyPadMode) else mode + pad = self._pt_pad if isinstance(img, torch.Tensor) and mode == "constant" else self._np_pad + return pad(img, self.to_pad, mode) + + +class SpatialPad(TorchOrNumpyTransform): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. + + If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used. + Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary). Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad for additional details. 
@@ -76,7 +141,7 @@ def __init__( ) -> None: self.spatial_size = spatial_size self.method: Method = Method(method) - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode = mode def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -88,7 +153,7 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int return pad_width return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)] - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -103,11 +168,11 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img - img = np.pad(img, all_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value) - return img + padder = Pad(all_pad_width, mode or self.mode) + return padder(img) -class BorderPad(Transform): +class BorderPad(TorchOrNumpyTransform): """ Pad the input data by adding specified borders to every dimension. @@ -133,9 +198,9 @@ def __init__( self, spatial_border: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT ) -> None: self.spatial_border = spatial_border - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode = mode - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first and @@ -168,13 +233,12 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N f"Unsupported spatial_border length: {len(spatial_border)}, available options are " f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - - return np.pad( - img, [(0, 0)] + data_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value - ) + all_pad_width = [(0, 0)] + data_pad_width + padder = Pad(all_pad_width, mode or self.mode) + return padder(img) -class DivisiblePad(Transform): +class DivisiblePad(TorchOrNumpyTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. 
""" @@ -193,9 +257,9 @@ def __init__(self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] See also :py:class:`monai.transforms.SpatialPad` """ self.k = k - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode = mode - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel-first diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 14d93aae4e..cfc44f5e66 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -16,6 +16,7 @@ from monai.transforms import BorderPad from monai.utils import NumpyPadMode +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"spatial_border": 2, "mode": "constant"}, @@ -45,11 +46,12 @@ class TestBorderPad(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_pad_shape(self, input_param, input_data, expected_val): - padder = BorderPad(**input_param) - result = padder(input_data) - self.assertAlmostEqual(result.shape, expected_val.shape) - result = padder(input_data, mode=input_param["mode"]) - self.assertAlmostEqual(result.shape, expected_val.shape) + for p in TEST_NDARRAYS: + padder = BorderPad(**input_param) + result = padder(p(input_data)) + self.assertAlmostEqual(result.shape, expected_val.shape) + result = padder(p(input_data), mode=input_param["mode"]) + self.assertAlmostEqual(result.shape, expected_val.shape) if __name__ == "__main__": diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 4473a23770..0a5590ba57 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -10,39 +10,76 @@ # limitations under the License. 
import unittest +from typing import List import numpy as np from parameterized import parameterized from monai.transforms import SpatialPad +from monai.utils.enums import NumpyPadMode +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"spatial_size": [15, 8, 8], "method": "symmetric", "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 15, 8, 8)), -] +TESTS = [] -TEST_CASE_2 = [ - {"spatial_size": [15, 8, 8], "method": "end", "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 15, 8, 8)), +# Numpy modes +MODES: List = [ + "constant", + "edge", + "linear_ramp", + "maximum", + "mean", + "median", + "minimum", + "reflect", + "symmetric", + "wrap", + "empty", ] +MODES += [NumpyPadMode(i) for i in MODES] -TEST_CASE_3 = [ - {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": "constant"}, - np.zeros((3, 8, 8, 4)), - np.zeros((3, 15, 8, 4)), -] +for mode in MODES: + TESTS.append( + [ + {"spatial_size": [50, 50], "method": "end", "mode": mode}, + (1, 2, 2), + (1, 50, 50), + ] + ) + + TESTS.append( + [ + {"spatial_size": [15, 4, -1], "method": "symmetric", "mode": mode}, + (3, 8, 8, 4), + (3, 15, 8, 4), + ] + ) class TestSpatialPad(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_pad_shape(self, input_param, input_data, expected_val): - padder = SpatialPad(**input_param) - result = padder(input_data) - np.testing.assert_allclose(result.shape, expected_val.shape) - result = padder(input_data, mode=input_param["mode"]) - np.testing.assert_allclose(result.shape, expected_val.shape) + def setUp(self) -> None: + set_determinism(seed=0) + + def tearDown(self) -> None: + set_determinism(None) + + @staticmethod + def get_arr(shape): + return np.random.randint(100, size=shape).astype(float) + + @parameterized.expand(TESTS) + def test_pad_shape(self, input_param, input_shape, expected_shape): + results_1 = [] + results_2 = [] + input_data = self.get_arr(input_shape) + for p in TEST_NDARRAYS: + padder = SpatialPad(**input_param) + results_1.append(padder(p(input_data))) + results_2.append(padder(p(input_data), mode=input_param["mode"])) + for results in (results_1, results_2): + np.testing.assert_allclose(results[-1].shape, expected_shape) + if input_param["mode"] not in ("empty", NumpyPadMode.EMPTY): + np.testing.assert_allclose(results[0], results[-1], rtol=1e-5) if __name__ == "__main__": From 12d514d32127e5b3e65a5deac9aafc744901159c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:19:31 +0000 Subject: [PATCH 039/176] update convert dtype Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 8873162d84..cb1d1868ef 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -336,7 +336,7 @@ def dtype_convert(dtype, data_type): else: # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them dtype = np.dtype(dtype) if type(dtype) == type else dtype - if type(dtype) is np.dtype: + if isinstance(dtype, np.dtype): return dtype return dtype_torch_to_numpy(dtype) From bc78766753f2efeff03cb9f670d33d2cfaa0b714 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:21:48 +0000 Subject: [PATCH 040/176] update convert_dtype Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- monai/utils/misc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index cb1d1868ef..e4331f1ce4 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -334,9 +334,7 @@ def dtype_convert(dtype, data_type): return dtype return dtype_numpy_to_torch(dtype) else: - # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them - dtype = np.dtype(dtype) if type(dtype) == type else dtype - if isinstance(dtype, np.dtype): + if type(dtype) is not torch.dtype: return dtype return dtype_torch_to_numpy(dtype) From 5208e8d2422584931424a3cb8d57beb40ff96199 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 17:30:26 +0000 Subject: [PATCH 041/176] gibbs_noised returns same type as input Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_gibbs_noised.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index b0db79fe4f..40e430abc6 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -53,7 +53,7 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): out2 = t(deepcopy(data)) for k in KEYS: np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) + self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) def test_identity(self, im_shape, _, as_tensor_input): From 80aa083b39e3d065f561c1d4fb3f1057d128440e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 14 Jun 2021 18:49:39 +0100 Subject: [PATCH 042/176] rand_rician_noised Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/dictionary.py | 4 +++ tests/test_rand_rician_noised.py | 44 +++++++++++++----------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 26e5d3a86c..eb5499d59f 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -205,6 +205,10 @@ def __init__( RandomizableTransform.__init__(self, global_prob) self.rand_rician_noise = RandRicianNoise(prob, mean, std, channel_wise, relative, sample_std) + def set_random_state(self, seed=None, state=None): + super().set_random_state(seed, state) + self.rand_rician_noise.set_random_state(seed, state) + def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Images]: d = dict(data) super().randomize(None) diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index 3dbfce154d..cc5d1c4a84 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandRicianNoised @@ -24,36 +25,39 @@ seed = 0 -def test_numpy_or_torch(keys, mean, std, imt): - rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) - rician_fn.set_random_state(seed) - rician_fn.rand_rician_noise.set_random_state(seed) - noised = rician_fn({k: imt for k in keys}) - np.random.seed(seed) - np.random.random() - np.random.seed(seed) - for k in keys: - np.random.random() - _std = np.random.uniform(0, std) - expected = np.sqrt( - (imt + 
np.random.normal(mean, _std, size=imt.shape)) ** 2 - + np.random.normal(mean, _std, size=imt.shape) ** 2 - ) - np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) - - # Test with numpy class TestRandRicianNoisedNumpy(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES) def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) + rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) + rician_fn.set_random_state(seed) + noised = rician_fn({k: self.imt for k in keys}) + np.random.seed(seed) + for k in keys: + np.random.random() + _std = np.random.uniform(0, std) + expected = np.sqrt( + (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2 + + np.random.normal(mean, _std, size=self.imt.shape) ** 2 + ) + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) # Test with torch class TestRandRicianNoisedTorch(TorchImageTestCase2D): @parameterized.expand(TEST_CASES) def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) + rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) + rician_fn.set_random_state(seed) + noised = rician_fn({k: self.imt for k in keys}) + torch.manual_seed(seed) + for k in keys: + _std = float(torch.rand(1)) * std + expected = torch.sqrt( + (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 + + torch.normal(mean, _std, size=self.imt.shape) ** 2 + ) + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) if __name__ == "__main__": From cfb1ff0d92a871c977211ecca76846bd82ae0cc6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 15 Jun 2021 09:50:37 +0100 Subject: [PATCH 043/176] nifti writer works with torch.tensor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/nifti_writer.py | 40 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index c56d4c1e8d..96076ba94c 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -17,13 +17,15 @@ from monai.config import DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform +from monai.transforms.transform import convert_data_type from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils.enums import DataObjects nib, _ = optional_import("nibabel") def write_nifti( - data: np.ndarray, + data: DataObjects.Images, file_name: str, affine: Optional[np.ndarray] = None, target_affine: Optional[np.ndarray] = None, @@ -36,7 +38,7 @@ def write_nifti( output_dtype: DtypeLike = np.float32, ) -> None: """ - Write numpy data into NIfTI files to disk. This function converts data + Write numpy or torch data into NIfTI files to disk. This function converts data into the coordinate system defined by `target_affine` when `target_affine` is specified. @@ -96,10 +98,14 @@ def write_nifti( If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. 
""" - if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") - dtype = dtype or data.dtype - sr = min(data.ndim, 3) + if not isinstance(data, (np.ndarray, torch.Tensor)): + raise AssertionError("input data must be numpy array or torch.Tensor.") + # if torch, convert to numpy + data_np: np.ndarray + data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore + + dtype = dtype or data_np.dtype + sr = min(data_np.ndim, 3) if affine is None: affine = np.eye(4, dtype=np.float64) affine = to_affine_nd(sr, affine) @@ -110,7 +116,7 @@ def write_nifti( if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, target_affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, target_affine)) nib.save(results_img, file_name) return @@ -118,11 +124,11 @@ def write_nifti( start_ornt = nib.orientations.io_orientation(affine) target_ornt = nib.orientations.io_orientation(target_affine) ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data.shape - data = nib.orientations.apply_orientation(data, ornt_transform) + data_shape = data_np.shape + data_np = nib.orientations.apply_orientation(data_np, ornt_transform) _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) if np.allclose(_affine, target_affine, atol=1e-3) or not resample: - results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine)) + results_img = nib.Nifti1Image(data_np.astype(output_dtype), to_affine_nd(3, _affine)) nib.save(results_img, file_name) return @@ -132,13 +138,13 @@ def write_nifti( ) transform = np.linalg.inv(_affine) @ target_affine if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape, _ = compute_shape_offset(data_np.shape, _affine, target_affine) output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] - if data.ndim > 3: # multi channel, resampling each channel + if data_np.ndim > 3: # multi channel, resampling each channel while len(output_spatial_shape_) < 3: output_spatial_shape_ = output_spatial_shape_ + [1] - spatial_shape, channel_shape = data.shape[:3], data.shape[3:] - data_np = data.reshape(list(spatial_shape) + [-1]) + spatial_shape, channel_shape = data_np.shape[:3], data_np.shape[3:] + data_np = data_np.reshape(list(spatial_shape) + [-1]) data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch data_torch = affine_xform( torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), @@ -149,12 +155,12 @@ def write_nifti( data_np = np.moveaxis(data_np, 0, -1) # channel last for nifti data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape)) else: # single channel image, need to expand to have batch and channel - while len(output_spatial_shape_) < len(data.shape): + while len(output_spatial_shape_) < len(data_np.shape): output_spatial_shape_ = output_spatial_shape_ + [1] data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)[None, None]), torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), - spatial_size=output_spatial_shape_[: len(data.shape)], + spatial_size=output_spatial_shape_[: len(data_np.shape)], ) data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() From 
db5a094a261cdf90e6fa31403016db83f82b3336 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 15 Jun 2021 09:55:47 +0100 Subject: [PATCH 044/176] test_image_dataset check np or torch dtype Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/io/array.py | 4 ++-- monai/transforms/transform.py | 2 +- tests/test_image_dataset.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 83c73ec25f..cfa3a229ec 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -23,7 +23,7 @@ from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver -from monai.transforms.transform import Transform +from monai.transforms.transform import Transform, convert_data_type from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, ensure_tuple, optional_import @@ -171,7 +171,7 @@ def __call__( # convert to desired output type if self.as_tensor: - img_array = torch.Tensor(img_array) + img_array, *_ = convert_data_type(img_array, torch.Tensor) if self.image_only: return img_array diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index ffe98c1b5b..14de89b796 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -53,7 +53,7 @@ def convert_data_type( if output_type is torch.Tensor: if orig_type is np.ndarray: - data = torch.Tensor(np.ascontiguousarray(data)) + data = torch.tensor(np.ascontiguousarray(data)) if dtype != data.dtype: data = data.to(dtype) # type: ignore elif output_type is np.ndarray: diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 173d24f350..eda36a14ca 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -18,6 +18,7 @@ from monai.data import ImageDataset from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing +from monai.utils.misc import dtype_convert FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -86,7 +87,7 @@ def test_dataset(self): # loading no meta, int dataset = ImageDataset(full_names, dtype=np.float16) for d, _ in zip(dataset, ref_data): - self.assertEqual(d.dtype, np.float16) + self.assertEqual(d.dtype, dtype_convert(np.float16, type(d))) # loading with meta, no transform dataset = ImageDataset(full_names, image_only=False) From 2d9258b105001044412259ea0108155d7117bbed Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 15 Jun 2021 10:32:45 +0100 Subject: [PATCH 045/176] guassian smooth needs float type Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 47bd331755..7b078967bc 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -958,6 +958,7 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t: torch.Tensor img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t = img_t.to(torch.float) gaussian_filter = GaussianFilter(img.ndim - 
1, self.sigma, approx=self.approx) out = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) From 8764a7b5ba14d91b7be0fd6d090d9fb6011365ca Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 15 Jun 2021 10:38:15 +0100 Subject: [PATCH 046/176] fix tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 2 +- tests/test_load_spacing_orientation.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 44795ee5a9..6b440502eb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -659,7 +659,7 @@ def __call__( out = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data - return self.post_convert_data(out, *orig_info) # type: ignore + return self.post_convert_data(out, orig_type, orig_device) # type: ignore class FgBgToIndices(Transform): diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index ac9b62f27b..5e92e8dd37 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -36,7 +36,7 @@ def test_load_spacingd(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 0.2, 1), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - im_np, _ = convert_data_type(data_dict["image"][0], np.ndarray) + im_np, *_ = convert_data_type(data_dict["image"][0], np.ndarray) anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() @@ -59,7 +59,7 @@ def test_load_spacingd_rotate(self, filename): res_dict = Spacingd(keys="image", pixdim=(1, 2, 3), diagonal=True, padding_mode="zeros")(data_dict) t1 = time.time() print(f"time monai: {t1 - t}") - im_np, _ = convert_data_type(data_dict["image"][0], np.ndarray) + im_np, *_ = convert_data_type(data_dict["image"][0], np.ndarray) anat = nibabel.Nifti1Image(im_np, data_dict["image_meta_dict"]["original_affine"]) ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() From ab9457e77c5ef75d52329ffe8821e0733b455a2a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 16 Jun 2021 12:05:56 +0100 Subject: [PATCH 047/176] fix gaussiansharpen Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 7b078967bc..9af6276237 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1053,6 +1053,7 @@ def __init__( def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t: torch.Tensor img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t = img_t.to(torch.float) gf1, gf2 = [GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) for sigma in (self.sigma1, self.sigma2)] blurred_f = gf1(img_t.unsqueeze(0)) From b2699fc4a9b16daae693797b3a18b5bfeb707b54 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 16 Jun 2021 12:09:51 +0100 Subject: [PATCH 048/176] LoadImaged allows torch.tensor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/io/dictionary.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index c3ecddb4e1..621e539ac1 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -18,6 +18,7 @@ from typing import Optional, Union import numpy as np +import torch from monai.config import DtypeLike, KeysCollection from monai.data.image_reader import ImageReader @@ -114,7 +115,7 @@ def __call__(self, data, reader: Optional[ImageReader] = None): for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): data = self._loader(d[key], reader) if self._loader.image_only: - if not isinstance(data, np.ndarray): + if not isinstance(data, (torch.Tensor, np.ndarray)): raise ValueError("loader must return a numpy array (because image_only=True was used).") d[key] = data else: From e366205fb426c370c2f28b26288f4078815bb50c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 16 Jun 2021 12:14:24 +0100 Subject: [PATCH 049/176] allow conversion from ITK Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 14de89b796..c594dd73fd 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -52,7 +52,7 @@ def convert_data_type( dtype = dtype_convert(dtype or data.dtype, output_type) if output_type is torch.Tensor: - if orig_type is np.ndarray: + if orig_type is not torch.Tensor: data = torch.tensor(np.ascontiguousarray(data)) if dtype != data.dtype: data = data.to(dtype) # type: ignore From 8c622d2b63e93e4b39d66bff49dc245e76e87834 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 28 Jun 2021 18:14:33 +0100 Subject: [PATCH 050/176] post-merge updates Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 2 +- monai/transforms/intensity/array.py | 11 +++-------- monai/transforms/intensity/dictionary.py | 2 +- monai/transforms/transform.py | 3 ++- monai/transforms/utility/array.py | 2 +- tests/test_decollate.py | 3 ++- 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index e95657aec6..d8e82b0881 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -20,7 +20,7 @@ from enum import Enum from itertools import chain from math import ceil, floor -from typing import Any, Callable, Dict, Hashable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 5fbd6efe5a..6a0b31cd30 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -32,14 +32,9 @@ convert_data_type, ) from monai.transforms.utils import rescale_array -from monai.utils import ( - PT_BEFORE_1_7, - InvalidPyTorchVersionError, - dtype_torch_to_numpy, - ensure_tuple, - ensure_tuple_rep, - ensure_tuple_size, -) +from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, ensure_tuple, ensure_tuple_rep, ensure_tuple_size +from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_convert __all__ = [ "RandGaussianNoise", diff --git 
a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 9b8fa7e0ff..fc5ec479ad 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -16,7 +16,7 @@ """ from collections.abc import Iterable -from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index a36194ee54..4b9fa473c7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Generator, Hashable, Iterable, List, Optional, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -39,6 +39,7 @@ ReturnType = TypeVar("ReturnType") + def convert_data_type( data: DataObjects.Images, output_type: Optional[type] = None, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 4b1f6fd9ee..6b440502eb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -32,7 +32,7 @@ convert_data_type, ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices -from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import +from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects PILImageImage, has_pil = optional_import("PIL.Image", name="Image") diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 813c849fab..b18857d82e 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -120,7 +120,8 @@ def check_match(self, in1, in2): def check_decollate(self, dataset): batch_size = 2 - num_workers = 2 + # num workers = 0 for mac or gpu transforms + num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) From 8908a99f712fc5cfd4e954c291639446c3228bf6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 28 Jun 2021 18:46:30 +0100 Subject: [PATCH 051/176] fix to_tensor 0 dim Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 11 ++++++++--- tests/test_to_tensor.py | 12 +++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 4b9fa473c7..d34f5d892b 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -51,16 +51,21 @@ def convert_data_type( orig_device = data.device if isinstance(data, torch.Tensor) else None output_type = output_type or orig_type - dtype = dtype_convert(dtype or data.dtype, output_type) + # objects like float don't have dtype, so return their type + dtype = dtype_convert(dtype or data.dtype, output_type) if hasattr(data, "dtype") else type(data) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = torch.tensor(np.ascontiguousarray(data)) + if orig_type is np.ndarray: + data = torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + else: + data = torch.as_tensor(data) if dtype != data.dtype: data = data.to(dtype) # type: ignore elif 
output_type is np.ndarray: if orig_type is torch.Tensor: data = data.detach().cpu().numpy() # type: ignore + else: + data = np.array(data) if dtype != data.dtype: data = data.astype(dtype) # type: ignore diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py index 6a36786914..e35a493bb8 100644 --- a/tests/test_to_tensor.py +++ b/tests/test_to_tensor.py @@ -58,10 +58,12 @@ def test_list_tuple(self): result = ToTensor()(test_data) torch.testing.assert_allclose(result, test_data) + def test_single_input(self): + for test_data in (5, np.asarray(5), torch.tensor(5)): + result = ToTensor()(test_data) + torch.testing.assert_allclose(result, test_data) + self.assertEqual(result.ndim, 0) + if __name__ == "__main__": - # unittest.main() - a = TestToTensor() - a.test_list_tuple() - a.test_numpy_input() - a.test_tensor_input() + unittest.main() From 5b382e3963edf78e7b360624026335378adbaab0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 2 Jul 2021 15:10:35 +0100 Subject: [PATCH 052/176] mypy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 2 +- monai/transforms/post/array.py | 29 +++++++++++++------------- monai/transforms/spatial/array.py | 2 +- tests/test_squeezedim.py | 12 +++++------ tests/test_transpose.py | 6 +++--- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d8e82b0881..a05f2377bd 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -461,7 +461,7 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, DataObjects.Images]) -> Dict[Hashable, DataObjects.Images]: d = deepcopy(dict(data)) for key in self.key_iterator(d): diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index fdca3a0c6c..8e37e89473 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -257,40 +257,41 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: Returns: A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]). 
""" - img, orig_type, orig_device = self.pre_conv_data(img) + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore - if img.shape[0] == 1: - img = torch.squeeze(img, dim=0) + if img_t.shape[0] == 1: + img_t = torch.squeeze(img_t, dim=0) if self.independent: for i in self.applied_labels: - foreground = (img == i).type(torch.uint8) + foreground = (img_t == i).type(torch.uint8) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 + img_t[foreground != mask] = 0 else: - foreground = torch.zeros_like(img) + foreground = torch.zeros_like(img_t) for i in self.applied_labels: - foreground += (img == i).type(torch.uint8) + foreground += (img_t == i).type(torch.uint8) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 + img_t[foreground != mask] = 0 - output = torch.unsqueeze(img, dim=0) + output = torch.unsqueeze(img_t, dim=0) else: # one-hot data is assumed to have binary value in each channel if self.independent: for i in self.applied_labels: - foreground = img[i, ...].type(torch.uint8) + foreground = img_t[i, ...].type(torch.uint8) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[i, ...][foreground != mask] = 0 + img_t[i, ...][foreground != mask] = 0 else: - applied_img = img[self.applied_labels, ...].type(torch.uint8) + applied_img = img_t[self.applied_labels, ...].type(torch.uint8) foreground = torch.any(applied_img, dim=0) mask = get_largest_connected_component_mask(foreground, self.connectivity) background_mask = torch.unsqueeze(foreground != mask, dim=0) background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) applied_img[background_mask] = 0 - img[self.applied_labels, ...] = applied_img.type(img.type()) # type: ignore - output = img # type: ignore + img_t[self.applied_labels, ...] 
= applied_img.type(img_t.type()) # type: ignore + output = img_t # type: ignore return self.post_convert_data(output, orig_type, orig_device) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index cf2d224810..8afc05078a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1339,7 +1339,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Tuple[DataObjects.Images, DataObjects.Images]: + ) -> Union[DataObjects.Images, Tuple[DataObjects.Images, DataObjects.Images]]: """ Args: img: shape must be (num_channels, H, W[, D]), diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py index d9ab217b9e..0d627a8ca8 100644 --- a/tests/test_squeezedim.py +++ b/tests/test_squeezedim.py @@ -19,13 +19,13 @@ TESTS, TESTS_FAIL = [], [] for r in (np.random.rand, torch.rand): - TESTS.append([{"dim": None}, r(1, 2, 1, 3), (2, 3)]) - TESTS.append([{"dim": 2}, r(1, 2, 1, 8, 16), (1, 2, 8, 16)]) - TESTS.append([{"dim": -1}, r(1, 1, 16, 8, 1), (1, 1, 16, 8)]) - TESTS.append([{}, r(1, 2, 1, 3), (2, 1, 3)]) + TESTS.append([{"dim": None}, r(1, 2, 1, 3), (2, 3)]) # type: ignore + TESTS.append([{"dim": 2}, r(1, 2, 1, 8, 16), (1, 2, 8, 16)]) # type: ignore + TESTS.append([{"dim": -1}, r(1, 1, 16, 8, 1), (1, 1, 16, 8)]) # type: ignore + TESTS.append([{}, r(1, 2, 1, 3), (2, 1, 3)]) # type: ignore - TESTS_FAIL.append([ValueError, {"dim": -2}, r(1, 1, 16, 8, 1)]) - TESTS_FAIL.append([TypeError, {"dim": 0.5}, r(1, 1, 16, 8, 1)]) + TESTS_FAIL.append([ValueError, {"dim": -2}, r(1, 1, 16, 8, 1)]) # type: ignore + TESTS_FAIL.append([TypeError, {"dim": 0.5}, r(1, 1, 16, 8, 1)]) # type: ignore class TestSqueezeDim(unittest.TestCase): diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 9ba9eeb4bc..3c30fba281 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -18,9 +18,9 @@ from monai.transforms import Transpose TEST_CASES = [] -for q in (np.arange, torch.Tensor): - TEST_CASES.append([q(5 * 4).reshape(5, 4), None]) - TEST_CASES.append([q(5 * 4 * 3).reshape(5, 4, 3), [2, 0, 1]]) +for q in (np.arange, torch.arange): + TEST_CASES.append([q(5 * 4).reshape(5, 4), None]) # type: ignore + TEST_CASES.append([q(5 * 4 * 3).reshape(5, 4, 3), [2, 0, 1]]) # type: ignore class TestTranspose(unittest.TestCase): From e2926b8e354c3cdb2bec05c865a68c560cd50801 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 10:18:43 +0100 Subject: [PATCH 053/176] add temp test script Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/rb_test_transforms.py | 63 +++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/rb_test_transforms.py diff --git a/tests/rb_test_transforms.py b/tests/rb_test_transforms.py new file mode 100644 index 0000000000..a182007954 --- /dev/null +++ b/tests/rb_test_transforms.py @@ -0,0 +1,63 @@ +from inspect import getmembers, isclass + +from monai import transforms +from monai.transforms import MapTransform, Transform +from monai.transforms.transform import NumpyTransform, ToDoTransform, TorchOrNumpyTransform, TorchTransform + + +class Colours: + red = "91" + green = "92" + yellow = "93" + light_purple = "94" + purple = "95" + cyan = "96" + light_gray = "97" + black = "98" + + +def print_colour(t, colour): + print(f"\033[{colour}m{t}\033[00m") + + +tr_total 
= 0 +tr_t_or_np = 0 +tr_t = 0 +tr_np = 0 +tr_todo = 0 +tr_uncategorised = 0 +for n, obj in getmembers(transforms): + if isclass(obj) and issubclass(obj, Transform) and not issubclass(obj, MapTransform): + if n in [ + "Transform", + "InvertibleTransform", + "Lambda", + "LoadImage", + "Compose", + "RandomizableTransform", + "ToPIL", + "ToCupy", + ]: + continue + tr_total += 1 + if issubclass(obj, TorchOrNumpyTransform): + tr_t_or_np += 1 + print_colour(f"TorchOrNumpy: {n}", Colours.green) + elif issubclass(obj, TorchTransform): + tr_t += 1 + print_colour(f"Torch: {n}", Colours.green) + elif issubclass(obj, NumpyTransform): + tr_np += 1 + print_colour(f"Numpy: {n}", Colours.yellow) + elif issubclass(obj, ToDoTransform): + tr_todo += 1 + print_colour(f"ToDoTransform: {n}", Colours.purple) + else: + tr_uncategorised += 1 + print_colour(f"Uncategorised: {n}", Colours.red) +print("Total number of transforms:", tr_total) +print_colour(f"Number of TorchOrNumpyTransform: {tr_t_or_np}", Colours.green) +print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) +print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) +print_colour(f"Number of ToDoTransform: {tr_todo}", Colours.purple) +print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) From 100e69ef708bf1ffd7080a1b91ac13378ae7408c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 15:42:25 +0100 Subject: [PATCH 054/176] update Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/img2tensorboard.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 63b68ba457..85a19b4c0f 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -9,12 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Optional, Sequence import numpy as np import torch -from monai.config import NdarrayTensor from monai.transforms import rescale_array from monai.utils import optional_import from monai.utils.enums import DataObjects @@ -189,7 +188,7 @@ def plot_2d_or_3d_image( d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index if d.ndim == 2: - d = rescale_array(d, 0, 1) + d = rescale_array(d, 0, 1) # type: ignore dataformats = "HW" writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats) return From 7d889d6a4db473544a55faebc684bf2773e5ee4c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 16:21:09 +0100 Subject: [PATCH 055/176] as_discrete test numpy and torch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/array.py | 2 +- tests/test_as_discrete.py | 47 ++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8e37e89473..99bb9fdebf 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -169,7 +169,7 @@ def __call__( img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore if argmax or self.argmax: - img_t = torch.argmax(img_t, dim=1, keepdim=True) + img_t = torch.argmax(img_t, dim=0, keepdim=True) if to_onehot or self.to_onehot: _nclasses = self.n_classes if n_classes is None else n_classes diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 658a21efd6..a2332f9274 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -15,31 +15,34 @@ from parameterized import parameterized from monai.transforms import AsDiscrete - -TEST_CASE_1 = [ - {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[1.0, 1.0]]]), - (1, 1, 2), -] - -TEST_CASE_2 = [ - {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), - (1, 2, 2), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([ + {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[1.0, 1.0]]]), + (1, 1, 2), + ]) + + TESTS.append([ + {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, + p([[[0.0, 1.0]], [[2.0, 3.0]]]), + p([[[0.0, 0.0]], [[1.0, 1.0]]]), + (2, 1, 2), + ]) + + TESTS.append([ + {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + p([[[0.0, 1.0], [1.0, 1.0]]]), + (1, 2, 2), + ]) class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) 
torch.testing.assert_allclose(result, out) From 8dbe874f74bc0ccc04ee4674570d94c476e43949 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 16:55:05 +0100 Subject: [PATCH 056/176] padder np_kwargs Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 22 +++++++++++------- tests/test_divisible_pad.py | 37 ++++++++++++++++++------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 404392cac2..5be402fa84 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -75,22 +75,24 @@ def __init__( self, to_pad: List[Tuple[int, int]], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ) -> None: self.to_pad = to_pad self.mode = mode + self.np_kwargs = np_kwargs @staticmethod - def _np_pad(img: DataObjects.Images, all_pad_width, mode) -> DataObjects.Images: + def _np_pad(img: DataObjects.Images, all_pad_width, mode, **np_kwargs) -> DataObjects.Images: out, orig_type, orig_device = convert_data_type(img, np.ndarray) - out = np.pad(out, all_pad_width, mode=mode) + out = np.pad(out, all_pad_width, mode=mode, **np_kwargs) out, *_ = convert_data_type(out, orig_type, orig_device) return out @staticmethod - def _pt_pad(img: DataObjects.Images, all_pad_width, mode) -> DataObjects.Images: + def _pt_pad(img: DataObjects.Images, all_pad_width, mode, **np_kwargs) -> DataObjects.Images: out, orig_type, orig_device = convert_data_type(img, torch.Tensor) pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1] - out = pad_pt(out, pt_pad_width, mode=mode) # type: ignore + out = pad_pt(out, pt_pad_width, mode=mode, **np_kwargs) # type: ignore out, *_ = convert_data_type(out, orig_type, orig_device) return out @@ -109,8 +111,11 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s return img mode = mode or self.mode mode = mode.value if isinstance(mode, NumpyPadMode) else mode - pad = self._pt_pad if isinstance(img, torch.Tensor) and mode == "constant" else self._np_pad - return pad(img, self.to_pad, mode) + if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs: + pad = self._pt_pad + else: + pad = self._np_pad + return pad(img, self.to_pad, mode, **self.np_kwargs) class SpatialPad(TorchOrNumpyTransform): @@ -171,12 +176,13 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s One of the listed string values or a user supplied function. Defaults to ``self.mode``. 
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ + mode = NumpyPadMode(mode or self.mode) data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img - padder = Pad(all_pad_width, mode or self.mode) + padder = Pad(all_pad_width, mode or self.mode, **self.np_kwargs) return padder(img) @@ -279,7 +285,7 @@ def __init__( See also :py:class:`monai.transforms.SpatialPad` """ self.k = k - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode = mode self.np_kwargs = np_kwargs def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 1ca7e9f46a..9ef096a1e8 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -15,24 +15,28 @@ from parameterized import parameterized from monai.transforms import DivisiblePad +from tests.utils import TEST_NDARRAYS -# pad first dim to be divisible by 7, the second unchanged. -TEST_CASE_1 = [ - {"k": (7, -1), "mode": "constant"}, - np.zeros((3, 8, 7)), - np.zeros((3, 14, 7)), -] +TESTS = [] -# pad all dimensions to be divisible by 5 -TEST_CASE_2 = [ - {"k": 5, "mode": "constant"}, - np.zeros((3, 10, 5, 17)), - np.zeros((3, 10, 5, 20)), -] +for p in TEST_NDARRAYS: + # pad first dim to be divisible by 7, the second unchanged. + TESTS.append([ + {"k": (7, -1), "mode": "constant"}, + p(np.zeros((3, 8, 7))), + p(np.zeros((3, 14, 7))), + ]) + + # pad all dimensions to be divisible by 5 + TESTS.append([ + {"k": 5, "mode": "constant"}, + p(np.zeros((3, 10, 5, 17))), + p(np.zeros((3, 10, 5, 20))), + ]) class TestDivisiblePad(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_pad_shape(self, input_param, input_data, expected_val): padder = DivisiblePad(**input_param) result = padder(input_data) @@ -42,9 +46,10 @@ def test_pad_shape(self, input_param, input_data, expected_val): def test_pad_kwargs(self): padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) - result = padder(np.zeros((3, 8, 4))) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4))) - np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1) + for p in TEST_NDARRAYS: + result = padder(p(np.zeros((3, 8, 4)))) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4))) + np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1) if __name__ == "__main__": From 4e843f50b50326c35a41ec94b965c18056f81722 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 16:57:01 +0100 Subject: [PATCH 057/176] fix borderpad Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 5be402fa84..c518bd29e1 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -255,7 +255,7 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." 
) all_pad_width = [(0, 0)] + data_pad_width - padder = Pad(all_pad_width, mode or self.mode) + padder = Pad(all_pad_width, mode or self.mode, **self.np_kwargs) return padder(img) From ee0f3faa42b26445def9b320a547f480fd5fa908 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 5 Jul 2021 17:02:30 +0100 Subject: [PATCH 058/176] fix test_invertd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1ceabcd83b..4ce05a4ea6 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -14,6 +14,7 @@ """ import logging +from monai.utils.misc import dtype_convert import sys import time import warnings @@ -303,10 +304,11 @@ def __call__( TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``. """ + dtype = dtype_convert(dtype or self.dtype, type(img)) if isinstance(img, np.ndarray): - return img.astype(dtype or self.dtype) # type: ignore + return img.astype(dtype) # type: ignore if isinstance(img, torch.Tensor): - return img.to(dtype=dtype or self.dtype) # type: ignore + return img.to(dtype=dtype) # type: ignore raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.") From 94dc56bed5da7bc4aa78bdee8766d300a819a280 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 10:49:39 +0100 Subject: [PATCH 059/176] fix gaussian smooth Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- tests/test_gaussian_smooth.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6a0b31cd30..5fa534877b 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -962,7 +962,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore img_t = img_t.to(torch.float) - gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx) + gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx).to(img_t.device) out = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) return self.post_convert_data(out, orig_type, orig_device) diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index f7f69a5d6a..accd373761 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -11,6 +11,7 @@ import unittest +import torch import numpy as np from parameterized import parameterized @@ -24,7 +25,7 @@ [ {"sigma": 1.5}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array( + p( [ [ [0.59167546, 0.69312394, 0.59167546], @@ -45,7 +46,7 @@ [ {"sigma": 0.5}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array( + p( [ [ [0.8424794, 0.99864554, 0.8424794], @@ -66,7 +67,7 @@ [ {"sigma": [1.5, 0.5]}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array( + p( [ [ [0.8542037, 1.0125432, 0.8542037], @@ -88,7 +89,7 @@ class TestGaussianSmooth(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = 
GaussianSmooth(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) if __name__ == "__main__": From b8c52dc961a5bbe56649121d472ddc58393099e2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 10:50:26 +0100 Subject: [PATCH 060/176] use torch.tensor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 3193de55c4..9418feff26 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -579,9 +579,9 @@ def query_memory(n=2): return ",".join([f"{int(x)}" for x in ids]) -TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.Tensor) # type: ignore +TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.tensor) # type: ignore if torch.cuda.is_available(): - gpu_tensor: Callable = partial(torch.Tensor, device="cuda") + gpu_tensor: Callable = partial(torch.tensor, device="cuda") TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore From ed8a42f52d98375d0967b9046699a848c60b929b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 11:52:02 +0100 Subject: [PATCH 061/176] zoom and randzoom Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 82 +++++++++++++++---------------- tests/test_rand_zoom.py | 82 +++++++++++++++++-------------- tests/test_zoom.py | 59 ++++++++++++---------- 3 files changed, 119 insertions(+), 104 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8afc05078a..7fdaff31e9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,7 +22,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import ( Randomizable, RandomizableTransform, @@ -519,7 +519,7 @@ def get_rotation_matrix(self) -> Optional[np.ndarray]: return self._rotation_matrix -class Zoom(ToDoTransform): +class Zoom(TorchTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/nn.functional.html#interpolate. @@ -560,11 +560,11 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, str]] = None, align_corners: Optional[bool] = None, - ): + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). 
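The Zoom rewrite below relies on the same torch-first round trip as the other reworked transforms; a minimal, self-contained sketch of the pattern (to_torch and to_original are hypothetical stand-ins for pre_conv_data and post_convert_data, which additionally track dtype):

import numpy as np
import torch


def to_torch(data):
    # record the incoming type/device so the transform can hand back the same kind of array
    orig_type = type(data)
    orig_device = data.device if isinstance(data, torch.Tensor) else None
    if isinstance(data, np.ndarray):
        data = torch.as_tensor(np.ascontiguousarray(data))
    return data, orig_type, orig_device


def to_original(data, orig_type, orig_device):
    # convert the torch result back to whatever came in
    if orig_type is np.ndarray:
        return data.detach().cpu().numpy()
    return data if orig_device is None else data.to(orig_device)


img = np.zeros((1, 4, 4), dtype=np.float32)
img_t, orig_type, orig_device = to_torch(img)
out = torch.flip(img_t, dims=[1])  # the torch-only computation in the middle
out = to_original(out, orig_type, orig_device)
assert isinstance(out, np.ndarray)  # numpy in, numpy out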
@@ -580,31 +580,37 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate """ - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim + img_t: torch.Tensor + img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + + _zoom = ensure_tuple_rep(self.zoom, img_t.ndim - 1) # match the spatial image dim zoomed = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, - input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0), + input=img_t.float().unsqueeze(0), scale_factor=list(_zoom), mode=self.mode.value if mode is None else InterpolateMode(mode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) - zoomed = zoomed.squeeze(0).detach().cpu().numpy() - if not self.keep_size or np.allclose(img.shape, zoomed.shape): - return zoomed + zoomed = zoomed.squeeze(0) + + if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): + + pad_vec = [[0, 0]] * len(img_t.shape) + slice_vec = [slice(None)] * len(img_t.shape) + for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): + diff = od - zd + half = abs(diff) // 2 + if diff > 0: # need padding + pad_vec[idx] = [half, diff - half] + elif diff < 0: # need slicing + slice_vec[idx] = slice(half, half + od) + + padder = Pad(pad_vec, padding_mode or self.padding_mode) + zoomed = padder(zoomed) + zoomed = zoomed[tuple(slice_vec)] - pad_vec = [[0, 0]] * len(img.shape) - slice_vec = [slice(None)] * len(img.shape) - for idx, (od, zd) in enumerate(zip(img.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) - padding_mode = self.padding_mode if padding_mode is None else NumpyPadMode(padding_mode) - zoomed = np.pad(zoomed, pad_vec, mode=padding_mode.value) - return zoomed[tuple(slice_vec)] + return self.post_convert_data(zoomed, orig_type, orig_device) class Rotate90(TorchOrNumpyTransform): @@ -837,7 +843,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return flipper(img) -class RandZoom(ToDoTransform, RandomizableTransform): +class RandZoom(TorchTransform, RandomizableTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -894,11 +900,11 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). 
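The keep_size branch in the Zoom hunk above restores the pre-zoom spatial shape by padding any dimension the interpolation shrank and cropping back any dimension it grew. A standalone sketch of that bookkeeping, with illustrative shapes and torch.nn.functional.pad called directly in place of the Pad transform:

import torch

orig_shape = (1, 10, 10)
zoomed = torch.zeros((1, 7, 13))  # stand-in for the interpolate() output

pad_vec = [(0, 0)] * len(orig_shape)
slice_vec = [slice(None)] * len(orig_shape)
for idx, (od, zd) in enumerate(zip(orig_shape, zoomed.shape)):
    diff = od - zd
    half = abs(diff) // 2
    if diff > 0:  # zoomed is smaller here, pad roughly evenly on both sides
        pad_vec[idx] = (half, diff - half)
    elif diff < 0:  # zoomed is larger here, slice out the middle
        slice_vec[idx] = slice(half, half + od)

# torch's pad takes (before, after) pairs starting from the last dimension
flat_pad = [p for pair in reversed(pad_vec) for p in pair]
restored = torch.nn.functional.pad(zoomed, flat_pad)[tuple(slice_vec)]
assert restored.shape == orig_shape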
@@ -915,25 +921,17 @@ def __call__(
         """
         # match the spatial image dim
         self.randomize()
-        _dtype = np.float32
         if not self._do_transform:
-            return img.astype(_dtype)
-        if len(self._zoom) == 1:
-            # to keep the spatial shape ratio, use same random zoom factor for all dims
-            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
-        elif len(self._zoom) == 2 and img.ndim > 3:
-            # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
-            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
-        zoomer = Zoom(self._zoom, keep_size=self.keep_size)
-        return np.asarray(
-            zoomer(
-                img,
-                mode=mode or self.mode,
-                padding_mode=padding_mode or self.padding_mode,
-                align_corners=self.align_corners if align_corners is None else align_corners,
-            ),
-            dtype=_dtype,
-        )
+            return img
+        if self._do_transform:
+            if len(self._zoom) == 1:
+                # to keep the spatial shape ratio, use same random zoom factor for all dims
+                self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
+            elif len(self._zoom) == 2 and img.ndim > 3:
+                # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
+                self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
+            zoomer = Zoom(self._zoom, keep_size=self.keep_size, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners)
+            return zoomer(img)
 
 
 class AffineGrid(ToDoTransform):

diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py
index 35cf30bcb1..4f805a89b3 100644
--- a/tests/test_rand_zoom.py
+++ b/tests/test_rand_zoom.py
@@ -12,12 +12,13 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 from scipy.ndimage import zoom as zoom_scipy
 
 from monai.transforms import RandZoom
 from monai.utils import GridSampleMode, InterpolateMode
-from tests.utils import NumpyImageTestCase2D
+from tests.utils import NumpyImageTestCase2D, TEST_NDARRAYS
 
 VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)]
 
@@ -25,29 +26,34 @@ class TestRandZoom(NumpyImageTestCase2D):
     @parameterized.expand(VALID_CASES)
     def test_correct_results(self, min_zoom, max_zoom, mode, keep_size):
-        random_zoom = RandZoom(
-            prob=1.0,
-            min_zoom=min_zoom,
-            max_zoom=max_zoom,
-            mode=mode,
-            keep_size=keep_size,
-        )
-        random_zoom.set_random_state(1234)
-        zoomed = random_zoom(self.imt[0])
-        expected = []
-        for channel in self.imt[0]:
-            expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False))
-        expected = np.stack(expected).astype(np.float32)
-        np.testing.assert_allclose(zoomed, expected, atol=1.0)
+        for p in TEST_NDARRAYS:
+            random_zoom = RandZoom(
+                prob=1.0,
+                min_zoom=min_zoom,
+                max_zoom=max_zoom,
+                mode=mode,
+                keep_size=keep_size,
+            )
+            random_zoom.set_random_state(1234)
+            zoomed = random_zoom(p(self.imt[0]))
+            if isinstance(zoomed, torch.Tensor):
+                zoomed = zoomed.cpu().numpy()
+            expected = []
+            for channel in self.imt[0]:
+                expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False))
+            expected = np.stack(expected).astype(np.float32)
+            np.testing.assert_allclose(zoomed, expected, atol=1.0)
 
     def test_keep_size(self):
-        random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
-        zoomed = random_zoom(self.imt[0])
-
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) - zoomed = random_zoom(self.imt[0]) - self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) + zoomed = random_zoom(im) + self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:])) @parameterized.expand( [ @@ -57,23 +63,25 @@ def test_keep_size(self): ] ) def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises): - with self.assertRaises(raises): - random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) - random_zoom(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode) + random_zoom(p(self.imt[0])) def test_auto_expand_3d(self): - random_zoom = RandZoom( - prob=1.0, - min_zoom=[0.8, 0.7], - max_zoom=[1.2, 1.3], - mode="nearest", - keep_size=False, - ) - random_zoom.set_random_state(1234) - test_data = np.random.randint(0, 2, size=[2, 2, 3, 4]) - zoomed = random_zoom(test_data) - np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) - np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) + for p in TEST_NDARRAYS: + random_zoom = RandZoom( + prob=1.0, + min_zoom=[0.8, 0.7], + max_zoom=[1.2, 1.3], + mode="nearest", + keep_size=False, + ) + random_zoom.set_random_state(1234) + test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4])) + zoomed = random_zoom(test_data) + np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2) + np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3)) if __name__ == "__main__": diff --git a/tests/test_zoom.py b/tests/test_zoom.py index dcc401f16c..d623c06eb5 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -11,12 +11,13 @@ import unittest +import torch import numpy as np from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -26,38 +27,46 @@ class TestZoom(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, zoom, mode): - zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(self.imt[0]) - _order = 0 - if mode.endswith("linear"): - _order = 1 - expected = [] - for channel in self.imt[0]: - expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) - expected = np.stack(expected).astype(np.float32) - np.testing.assert_allclose(zoomed, expected, atol=1.0) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) + zoomed = zoom_fn(p(self.imt[0])) + _order = 0 + if mode.endswith("linear"): + _order = 1 + expected = [] + for channel in self.imt[0]: + expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) + expected = np.stack(expected).astype(np.float32) + if isinstance(zoomed, torch.Tensor): + zoomed = zoomed.detach().cpu().numpy() + np.testing.assert_allclose(zoomed, expected, atol=1.0) def test_keep_size(self): - zoom_fn = Zoom(zoom=[0.6, 0.6], 
keep_size=True, align_corners=True) - zoomed = zoom_fn(self.imt[0], mode="bilinear") - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) + zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") + np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) - zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(self.imt[0]) - np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) + zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) + zoomed = zoom_fn(p(self.imt[0])) + if isinstance(zoomed, torch.Tensor): + zoomed = zoomed.detach().cpu().numpy() + np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:]) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): - with self.assertRaises(raises): - zoom_fn = Zoom(zoom=zoom, mode=mode) - zoom_fn(self.imt[0]) + for p in TEST_NDARRAYS: + with self.assertRaises(raises): + zoom_fn = Zoom(zoom=zoom, mode=mode) + zoom_fn(p(self.imt[0])) def test_padding_mode(self): - zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) - test_data = np.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) - zoomed = zoom_fn(test_data) - expected = np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) - np.testing.assert_allclose(zoomed, expected) + for p in TEST_NDARRAYS: + zoom_fn = Zoom(zoom=0.5, mode="nearest", padding_mode="constant", keep_size=True) + test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]]) + zoomed = zoom_fn(test_data) + expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]) + torch.testing.assert_allclose(zoomed, expected) if __name__ == "__main__": From 359fca1e683babe6db4cebfeb382633175b63c1f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 12:00:00 +0100 Subject: [PATCH 062/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 30 ++++++++++++-------- monai/transforms/utility/array.py | 2 +- tests/test_as_discrete.py | 46 +++++++++++++++++-------------- tests/test_divisible_pad.py | 24 +++++++++------- tests/test_gaussian_smooth.py | 1 - tests/test_rand_zoom.py | 2 +- tests/test_zoom.py | 2 +- 7 files changed, 61 insertions(+), 46 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7fdaff31e9..7b26fe99e2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -595,13 +595,13 @@ def __call__( if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): - pad_vec = [[0, 0]] * len(img_t.shape) + pad_vec = [(0, 0)] * len(img_t.shape) slice_vec = [slice(None)] * len(img_t.shape) for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): diff = od - zd half = abs(diff) // 2 if diff > 0: # need padding - pad_vec[idx] = [half, diff - half] + pad_vec[idx] = (half, diff - half) elif diff < 0: # need slicing slice_vec[idx] = slice(half, half + od) @@ -609,7 +609,6 @@ def __call__( zoomed = padder(zoomed) zoomed = zoomed[tuple(slice_vec)] - return self.post_convert_data(zoomed, orig_type, orig_device) @@ -921,17 +920,24 @@ def __call__( """ # match the spatial image dim self.randomize() + if not self._do_transform: return img - if self._do_transform: - if 
len(self._zoom) == 1:
-                # to keep the spatial shape ratio, use same random zoom factor for all dims
-                self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
-            elif len(self._zoom) == 2 and img.ndim > 3:
-                # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
-                self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
-            zoomer = Zoom(self._zoom, keep_size=self.keep_size, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners)
-            return zoomer(img)
+
+        if len(self._zoom) == 1:
+            # to keep the spatial shape ratio, use same random zoom factor for all dims
+            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
+        elif len(self._zoom) == 2 and img.ndim > 3:
+            # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
+            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
+        zoomer = Zoom(
+            self._zoom,
+            keep_size=self.keep_size,
+            mode=mode or self.mode,
+            padding_mode=padding_mode or self.padding_mode,
+            align_corners=self.align_corners if align_corners is None else align_corners,
+        )
+        return zoomer(img)
 
 
 class AffineGrid(ToDoTransform):

diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 4ce05a4ea6..3641e16521 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -14,7 +14,6 @@
 """
 
 import logging
-from monai.utils.misc import dtype_convert
 import sys
 import time
 import warnings
@@ -35,6 +34,7 @@
 from monai.transforms.utils import extreme_points_to_image, get_extreme_points, map_binary_to_indices
 from monai.utils import ensure_tuple, min_version, optional_import
 from monai.utils.enums import DataObjects
+from monai.utils.misc import dtype_convert
 
 PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
 pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py
index a2332f9274..e5e6bade41 100644
--- a/tests/test_as_discrete.py
+++ b/tests/test_as_discrete.py
@@ -19,26 +19,32 @@
 
 TESTS = []
 for p in TEST_NDARRAYS:
-    TESTS.append([
-        {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5},
-        p([[[0.0, 1.0]], [[2.0, 3.0]]]),
-        p([[[1.0, 1.0]]]),
-        (1, 1, 2),
-    ])
-
-    TESTS.append([
-        {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5},
-        p([[[0.0, 1.0]], [[2.0, 3.0]]]),
-        p([[[0.0, 0.0]], [[1.0, 1.0]]]),
-        (2, 1, 2),
-    ])
-
-    TESTS.append([
-        {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6},
-        p([[[0.0, 1.0], [2.0, 3.0]]]),
-        p([[[0.0, 1.0], [1.0, 1.0]]]),
-        (1, 2, 2),
-    ])
+    TESTS.append(
+        [
+            {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5},
+            p([[[0.0, 1.0]], [[2.0, 3.0]]]),
+            p([[[1.0, 1.0]]]),
+            (1, 1, 2),
+        ]
+    )
+
+    TESTS.append(
+        [
+            {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5},
+            p([[[0.0, 1.0]], [[2.0, 3.0]]]),
+            p([[[0.0, 0.0]], [[1.0, 1.0]]]),
+            (2, 1, 2),
+        ]
+    )
+
+    TESTS.append(
+        [
+            {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6},
+            p([[[0.0, 1.0], [2.0, 3.0]]]),
+            p([[[0.0, 1.0], [1.0, 1.0]]]),
+            (1, 2, 2),
+        ]
+    )
 
 
 class TestAsDiscrete(unittest.TestCase):
diff --git a/tests/test_divisible_pad.py 
b/tests/test_divisible_pad.py index 9ef096a1e8..cd02418d68 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -21,18 +21,22 @@ for p in TEST_NDARRAYS: # pad first dim to be divisible by 7, the second unchanged. - TESTS.append([ - {"k": (7, -1), "mode": "constant"}, - p(np.zeros((3, 8, 7))), - p(np.zeros((3, 14, 7))), - ]) + TESTS.append( + [ + {"k": (7, -1), "mode": "constant"}, + p(np.zeros((3, 8, 7))), + p(np.zeros((3, 14, 7))), + ] + ) # pad all dimensions to be divisible by 5 - TESTS.append([ - {"k": 5, "mode": "constant"}, - p(np.zeros((3, 10, 5, 17))), - p(np.zeros((3, 10, 5, 20))), - ]) + TESTS.append( + [ + {"k": 5, "mode": "constant"}, + p(np.zeros((3, 10, 5, 17))), + p(np.zeros((3, 10, 5, 20))), + ] + ) class TestDivisiblePad(unittest.TestCase): diff --git a/tests/test_gaussian_smooth.py b/tests/test_gaussian_smooth.py index accd373761..4dc7e35a9f 100644 --- a/tests/test_gaussian_smooth.py +++ b/tests/test_gaussian_smooth.py @@ -12,7 +12,6 @@ import unittest import torch -import numpy as np from parameterized import parameterized from monai.transforms import GaussianSmooth diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py index 4f805a89b3..b0e280080a 100644 --- a/tests/test_rand_zoom.py +++ b/tests/test_rand_zoom.py @@ -18,7 +18,7 @@ from monai.transforms import RandZoom from monai.utils import GridSampleMode, InterpolateMode -from tests.utils import NumpyImageTestCase2D, TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)] diff --git a/tests/test_zoom.py b/tests/test_zoom.py index d623c06eb5..3875046ded 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -11,8 +11,8 @@ import unittest -import torch import numpy as np +import torch from parameterized import parameterized from scipy.ndimage import zoom as zoom_scipy From 4bb16caf7bb83a5fae9f458d8b2280f683dee29e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:22:21 +0100 Subject: [PATCH 063/176] fix std shift intensity test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_std_shift_intensity.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index 1d385a54a4..63c50dd65c 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -11,6 +11,7 @@ import unittest +import torch import numpy as np from monai.transforms import ShiftIntensity, StdShiftIntensity @@ -27,18 +28,17 @@ def test_value(self): expected = shifter(imt) std_shifter = StdShiftIntensity(factor=factor) result = std_shifter(imt) - np.testing.assert_allclose(result, expected, rtol=1e-5) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_zerostd(self): - image = np.ones([2, 3, 3]) for p in TEST_NDARRAYS: - image = p(image) + image = p(np.ones([2, 3, 3], dtype=np.float32)) for nonzero in [True, False]: for channel_wise in [True, False]: factor = np.random.rand() std_shifter = StdShiftIntensity(factor=factor, nonzero=nonzero, channel_wise=channel_wise) result = std_shifter(image) - np.testing.assert_allclose(result, image, rtol=1e-5) + torch.testing.assert_allclose(result, image, atol=0, rtol=1e-5) def test_nonzero(self): for p in TEST_NDARRAYS: @@ -46,8 +46,8 @@ def test_nonzero(self): factor = np.random.rand() std_shifter = 
StdShiftIntensity(factor=factor, nonzero=True) result = std_shifter(image) - expected = np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]]) - np.testing.assert_allclose(result, expected, rtol=1e-5) + expected = p(np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]], dtype=np.float32)) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_channel_wise(self): for p in TEST_NDARRAYS: @@ -55,8 +55,8 @@ def test_channel_wise(self): factor = np.random.rand() std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) result = std_shifter(image) - expected = np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))) - np.testing.assert_allclose(result, expected, rtol=1e-5) + expected = p(np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32)) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_dtype(self): trans_dtype = np.float32 From 86399b301c5f8ca65d3a44b61af69d17fe37a953 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:25:46 +0100 Subject: [PATCH 064/176] fix test_rotate90 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rotate90.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index 83959ba3a5..d8225228bc 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import Rotate90 from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -25,8 +26,8 @@ def test_rotate90_default(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) def test_k(self): rotate = Rotate90(k=2) @@ -35,8 +36,8 @@ def test_k(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) @@ -45,8 +46,8 @@ def test_spatial_axes(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, -1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) @@ -55,8 +56,8 @@ def test_prob_k_spatial_axes(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) if __name__ == "__main__": From bdc268763dac8fa10126c8c9972d33f9db3505ca Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:29:03 +0100 Subject: [PATCH 065/176] fix rand_gaussian_sharpen Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- 
monai/transforms/intensity/array.py | 2 +- tests/test_rand_gaussian_sharpen.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 5fa534877b..5bb2dcb9bc 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1057,7 +1057,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore img_t = img_t.to(torch.float) - gf1, gf2 = [GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) for sigma in (self.sigma1, self.sigma2)] + gf1, gf2 = [GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) for sigma in (self.sigma1, self.sigma2)] blurred_f = gf1(img_t.unsqueeze(0)) filter_blurred_f = gf2(blurred_f) out = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index beb21f2acc..36dc2651a6 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianSharpen @@ -24,7 +25,7 @@ [ {"prob": 1.0}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [5.2919216, 5.5854445, 5.29192], @@ -53,7 +54,7 @@ "prob": 1.0, }, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [4.1071496, 3.597953, 4.1071477], @@ -82,7 +83,7 @@ "prob": 1.0, }, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [4.81077, 4.4237204, 4.81077], @@ -112,7 +113,7 @@ "prob": 1.0, }, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [4.430213, 3.2278745, 4.4302144], @@ -132,7 +133,7 @@ def test_value(self, argments, image, expected_data): converter = RandGaussianSharpen(**argments) converter.set_random_state(seed=0) result = converter(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) self.assertIsInstance(result, type(image)) From fd82b7e3fcfbc8c6295fdeb724a3de0925ec66ec Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:32:21 +0100 Subject: [PATCH 066/176] fix test_detect_envelope Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_detect_envelope.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index 654183f01a..b82d993c10 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DetectEnvelope @@ -133,7 +134,7 @@ class TestDetectEnvelope(unittest.TestCase): def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: result = DetectEnvelope(**arguments)(p(image)) - np.testing.assert_allclose(result, expected_data, atol=atol) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), atol=atol, rtol=1e-7) @parameterized.expand( [ From 32cc1f924e8141125d5f41f4a9bb24632934841e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 
6 Jul 2021 14:33:33 +0100 Subject: [PATCH 067/176] fix test_rand_std_shift_intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_std_shift_intensity.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py index ae22596b67..a036734b2b 100644 --- a/tests/test_rand_std_shift_intensity.py +++ b/tests/test_rand_std_shift_intensity.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import RandStdShiftIntensity from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -26,8 +27,8 @@ def test_value(self): np.random.seed(0) factor = np.random.uniform(low=-1.0, high=1.0) offset = factor * np.std(self.imt) - expected = self.imt + offset - np.testing.assert_allclose(result, expected, rtol=1e-5) + expected = p(self.imt + offset) + torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) if __name__ == "__main__": From a2b9bd89e3fd2ba2906779e859c2c750d39d6cf2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:35:05 +0100 Subject: [PATCH 068/176] fix test_scale_intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_scale_intensity.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_scale_intensity.py b/tests/test_scale_intensity.py index 89e065cb6b..2c28cec73c 100644 --- a/tests/test_scale_intensity.py +++ b/tests/test_scale_intensity.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import ScaleIntensity from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -25,15 +26,15 @@ def test_range_scale(self): mina = self.imt.min() maxa = self.imt.max() norm = (self.imt - mina) / (maxa - mina) - expected = (norm * (2.0 - 1.0)) + 1.0 - np.testing.assert_allclose(result, expected) + expected = p((norm * (2.0 - 1.0)) + 1.0) + torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0) def test_factor_scale(self): for p in TEST_NDARRAYS: scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1) result = scaler(p(self.imt)) - expected = (self.imt * (1 + 0.1)).astype(np.float32) - np.testing.assert_allclose(result, expected) + expected = p((self.imt * (1 + 0.1)).astype(np.float32)) + torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0) if __name__ == "__main__": From af28b99d94df1f201e2217016e0514594d226938 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:39:03 +0100 Subject: [PATCH 069/176] fix test_savitzky_golay_smooth Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_savitzky_golay_smooth.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index bd3be49f11..1821cdc5f9 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -13,6 +13,7 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import SavitzkyGolaySmooth from tests.utils import TEST_NDARRAYS @@ -60,8 +61,9 @@ class TestSavitzkyGolaySmooth(unittest.TestCase): @parameterized.expand([TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH]) def test_value(self, arguments, image, expected_data, atol): - result = SavitzkyGolaySmooth(**arguments)(image) - 
np.testing.assert_allclose(result, expected_data, atol=atol) + for p in TEST_NDARRAYS: + result = SavitzkyGolaySmooth(**arguments)(p(image)) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) class TestSavitzkyGolaySmoothREP(unittest.TestCase): @@ -69,4 +71,7 @@ class TestSavitzkyGolaySmoothREP(unittest.TestCase): def test_value(self, arguments, image, expected_data, atol): for p in TEST_NDARRAYS: result = SavitzkyGolaySmooth(**arguments)(p(image)) - np.testing.assert_allclose(result, expected_data, atol=atol) + torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) + +if __name__ == "__main__": + unittest.main() From 3c35c9f8760eda5fe9557ee092d4793c1f3ab3c5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:47:58 +0100 Subject: [PATCH 070/176] fix test_mask_intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- monai/transforms/transform.py | 2 +- tests/test_mask_intensity.py | 21 +++++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 5bb2dcb9bc..ff758a5f4c 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -851,7 +851,7 @@ def __call__(self, img: DataObjects.Images, mask_data: Optional[DataObjects.Imag "When mask_data is not single channel, mask_data channels must match img, " f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." ) - mask_data_, *_ = convert_data_type(mask_data_, type(img)) + mask_data_, *_ = convert_data_type(mask_data_, type(img), device=img.device if isinstance(img, torch.Tensor) else None) return img * mask_data_ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d34f5d892b..d9c790c91e 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -70,7 +70,7 @@ def convert_data_type( data = data.astype(dtype) # type: ignore if isinstance(data, torch.Tensor) and device is not None: - data.to(device) + data = data.to(device) return data, orig_type, orig_device diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 5a33ef9171..31c0d7d6d8 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -13,6 +13,7 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import MaskIntensity from tests.utils import TEST_NDARRAYS @@ -24,23 +25,23 @@ TEST_CASES.append( [ - {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), ] ) TEST_CASES.append( [ - {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]), + {"mask_data": p([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 
0], [0, 5, 0], [0, 0, 0]]]), ] ) TEST_CASES.append( [ - {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, # type: ignore - q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), # type: ignore - np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), + {"mask_data": p([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])}, + q([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), + q([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]), ] ) @@ -49,7 +50,7 @@ class TestMaskIntensity(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_value(self, argments, image, expected_data): result = MaskIntensity(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) if __name__ == "__main__": From f8f4b231de5cae87dc1933b49ffcb6dc5374bb21 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:49:09 +0100 Subject: [PATCH 071/176] fix test_rand_scale_intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_scale_intensity.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_rand_scale_intensity.py b/tests/test_rand_scale_intensity.py index 22fa0f4ce8..74a0be164e 100644 --- a/tests/test_rand_scale_intensity.py +++ b/tests/test_rand_scale_intensity.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import RandScaleIntensity from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -24,8 +25,8 @@ def test_value(self): scaler.set_random_state(seed=0) result = scaler(p(self.imt)) np.random.seed(0) - expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32) - np.testing.assert_allclose(result, expected) + expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)) + torch.testing.assert_allclose(result, expected, rtol=1e-7, atol=0) if __name__ == "__main__": From 52c8ced74afb4a0b014648f38515b6db8542a2e9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:52:59 +0100 Subject: [PATCH 072/176] fix test_spatial_pad Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_spatial_pad.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 7db6c08b49..9b88e6f71e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -14,6 +14,7 @@ import numpy as np from parameterized import parameterized +import torch from monai.transforms import SpatialPad from monai.utils.enums import NumpyPadMode @@ -79,15 +80,16 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): for results in (results_1, results_2): np.testing.assert_allclose(results[-1].shape, expected_shape) if input_param["mode"] not in ("empty", NumpyPadMode.EMPTY): - np.testing.assert_allclose(results[0], results[-1], rtol=1e-5) + torch.testing.assert_allclose(p(results[0]), results[-1], atol=0, rtol=1e-5) def test_pad_kwargs(self): padder = SpatialPad( spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) ) - result = padder(np.zeros((3, 8, 4))) - np.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4))) - 
np.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1) + for p in TEST_NDARRAYS: + result = padder(p(np.zeros((3, 8, 4)))) + torch.testing.assert_allclose(result[:, 8:, :4], p(np.ones((3, 7, 4))), rtol=1e-7, atol=0) + torch.testing.assert_allclose(result[:, :, 4:], p(np.ones((3, 15, 4)) + 1), rtol=1e-7, atol=0) if __name__ == "__main__": From ed8a82b18fc72cf2d81c2a8fdeb960713606e7dc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 14:57:51 +0100 Subject: [PATCH 073/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 9 +++++++-- tests/test_mask_intensity.py | 3 +-- tests/test_rand_gaussian_sharpen.py | 1 - tests/test_rotate90.py | 8 ++++---- tests/test_savitzky_golay_smooth.py | 3 ++- tests/test_spatial_pad.py | 2 +- tests/test_std_shift_intensity.py | 6 ++++-- 7 files changed, 19 insertions(+), 13 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ff758a5f4c..7e9eae5b78 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -851,7 +851,9 @@ def __call__(self, img: DataObjects.Images, mask_data: Optional[DataObjects.Imag "When mask_data is not single channel, mask_data channels must match img, " f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}." ) - mask_data_, *_ = convert_data_type(mask_data_, type(img), device=img.device if isinstance(img, torch.Tensor) else None) + mask_data_, *_ = convert_data_type( + mask_data_, type(img), device=img.device if isinstance(img, torch.Tensor) else None + ) return img * mask_data_ @@ -1057,7 +1059,10 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore img_t = img_t.to(torch.float) - gf1, gf2 = [GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) for sigma in (self.sigma1, self.sigma2)] + gf1, gf2 = [ + GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) + for sigma in (self.sigma1, self.sigma2) + ] blurred_f = gf1(img_t.unsqueeze(0)) filter_blurred_f = gf2(blurred_f) out = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0) diff --git a/tests/test_mask_intensity.py b/tests/test_mask_intensity.py index 31c0d7d6d8..0086409c16 100644 --- a/tests/test_mask_intensity.py +++ b/tests/test_mask_intensity.py @@ -11,9 +11,8 @@ import unittest -import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import MaskIntensity from tests.utils import TEST_NDARRAYS diff --git a/tests/test_rand_gaussian_sharpen.py b/tests/test_rand_gaussian_sharpen.py index 36dc2651a6..f511555f21 100644 --- a/tests/test_rand_gaussian_sharpen.py +++ b/tests/test_rand_gaussian_sharpen.py @@ -11,7 +11,6 @@ import unittest -import numpy as np import torch from parameterized import parameterized diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py index d8225228bc..7780976d3c 100644 --- a/tests/test_rotate90.py +++ b/tests/test_rotate90.py @@ -27,7 +27,7 @@ def test_rotate90_default(self): for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) expected = p(np.stack(expected)) - torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = Rotate90(k=2) @@ -37,7 +37,7 @@ def 
test_k(self): for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = p(np.stack(expected)) - torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = Rotate90(spatial_axes=(0, -1)) @@ -47,7 +47,7 @@ def test_spatial_axes(self): for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, -1))) expected = p(np.stack(expected)) - torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = Rotate90(k=2, spatial_axes=(0, 1)) @@ -57,7 +57,7 @@ def test_prob_k_spatial_axes(self): for channel in self.imt[0]: expected.append(np.rot90(channel, 2, (0, 1))) expected = p(np.stack(expected)) - torch.testing.assert_allclose(rotated, expected, rtol=1.e-5, atol=1.e-8) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": diff --git a/tests/test_savitzky_golay_smooth.py b/tests/test_savitzky_golay_smooth.py index 1821cdc5f9..45d0ea3e4d 100644 --- a/tests/test_savitzky_golay_smooth.py +++ b/tests/test_savitzky_golay_smooth.py @@ -12,8 +12,8 @@ import unittest import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import SavitzkyGolaySmooth from tests.utils import TEST_NDARRAYS @@ -73,5 +73,6 @@ def test_value(self, arguments, image, expected_data, atol): result = SavitzkyGolaySmooth(**arguments)(p(image)) torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-7, atol=atol) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 9b88e6f71e..142c0c3e4e 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -13,8 +13,8 @@ from typing import List import numpy as np -from parameterized import parameterized import torch +from parameterized import parameterized from monai.transforms import SpatialPad from monai.utils.enums import NumpyPadMode diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py index 63c50dd65c..5c16e14c45 100644 --- a/tests/test_std_shift_intensity.py +++ b/tests/test_std_shift_intensity.py @@ -11,8 +11,8 @@ import unittest -import torch import numpy as np +import torch from monai.transforms import ShiftIntensity, StdShiftIntensity from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -55,7 +55,9 @@ def test_channel_wise(self): factor = np.random.rand() std_shifter = StdShiftIntensity(factor=factor, channel_wise=True) result = std_shifter(image) - expected = p(np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32)) + expected = p( + np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32) + ) torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5) def test_dtype(self): From 8b53da5dcea94ed3724d017f121795e7e003a8fc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 6 Jul 2021 16:32:57 +0100 Subject: [PATCH 074/176] set num workers=0 for mac Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_arraydataset.py | 4 ++-- tests/test_dataloader.py | 2 +- tests/test_grid_dataset.py | 2 +- tests/test_inverse.py | 2 +- tests/test_patch_dataset.py | 2 +- 5 
files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index f6459cc88c..bf16aacb26 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -125,7 +125,7 @@ def test_dataloading_img(self, img_transform, expected_shape): dataset = ArrayDataset(test_images, img_transform) self.assertEqual(len(dataset), 2) dataset.set_random_state(1234) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=10, num_workers=n_workers) imgs = next(iter(loader)) # test batching np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape)) @@ -151,7 +151,7 @@ def test_dataloading_img_label(self, img_transform, expected_shape): dataset = ArrayDataset(test_images, img_transform, test_labels, img_transform) self.assertEqual(len(dataset), 2) dataset.set_random_state(1234) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 loader = DataLoader(dataset, batch_size=10, num_workers=n_workers) data = next(iter(loader)) # test batching np.testing.assert_allclose(data[0].shape, [2] + list(expected_shape)) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 3b159fb5b8..2c9f85fec3 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -48,7 +48,7 @@ def test_values(self): ] ) dataset = CacheDataset(data=datalist, transform=transform, cache_rate=0.5, cache_num=1) - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=n_workers) for d in dataloader: self.assertEqual(d["image"][0], "spleen_19.nii.gz") diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 6e0aa4023e..045a102955 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -36,7 +36,7 @@ def test_shape(self): test_dataset = ["vwxyz", "helloworld", "worldfoobar"] result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 70aecf222b..eb53743d12 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -634,7 +634,7 @@ def test_inverse_inferred_seg(self, extra_transform): batch_size = 10 # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + num_workers = 2 if sys.platform == "linux" else 0 transforms = Compose( [ AddChanneld(KEYS), diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 4f6e9a25fd..119da0bd2c 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -32,7 +32,7 @@ def test_shape(self): result = PatchDataset(dataset=test_dataset, patch_func=identity, samples_per_image=n_per_image) output = [] - n_workers = 0 if sys.platform == "win32" else 2 + n_workers = 2 if sys.platform == "linux" else 0 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) expected = ["vwx", "yzh", "ell", "owo", "rld"] From b82c6f10a56efa1b7757cfa6796efd44b6a8f85a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 09:53:02 
+0100 Subject: [PATCH 075/176] fix test_divisible_pad Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_divisible_pad.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index cd02418d68..c51c8da625 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import DivisiblePad @@ -52,8 +53,8 @@ def test_pad_kwargs(self): padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) for p in TEST_NDARRAYS: result = padder(p(np.zeros((3, 8, 4)))) - np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4))) - np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1) + torch.testing.assert_allclose(result[:, :1, :4], p(np.ones((3, 1, 4))), rtol=1e-7, atol=0) + torch.testing.assert_allclose(result[:, :, 4:5], p(np.ones((3, 10, 1)) + 1), rtol=1e-7, atol=0) if __name__ == "__main__": From 4c71aed7416200b5fb643a72e64ba7af079baca9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 09:54:56 +0100 Subject: [PATCH 076/176] fix test_gaussian_sharpen Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_gaussian_sharpen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_gaussian_sharpen.py b/tests/test_gaussian_sharpen.py index cdd4147c41..66fca04311 100644 --- a/tests/test_gaussian_sharpen.py +++ b/tests/test_gaussian_sharpen.py @@ -11,7 +11,7 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import GaussianSharpen @@ -24,7 +24,7 @@ [ {}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [4.1081963, 3.4950666, 4.1081963], @@ -41,7 +41,7 @@ [ {"sigma1": 1.0, "sigma2": 0.75, "alpha": 20}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [4.513644, 4.869134, 4.513644], @@ -62,7 +62,7 @@ [ {"sigma1": (0.5, 1.0), "sigma2": (0.5, 0.75), "alpha": 20}, p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array( + p( [ [ [3.3324685, 3.335536, 3.3324673], @@ -84,7 +84,7 @@ class TestGaussianSharpen(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = GaussianSharpen(**argments)(image) - np.testing.assert_allclose(result, expected_data, rtol=1e-4) + torch.testing.assert_allclose(result, expected_data, atol=0, rtol=1e-4) if __name__ == "__main__": From 87c6d83296c4442f73c6152ec50dd7df6eff7b6b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 09:57:58 +0100 Subject: [PATCH 077/176] fix test_rand_rotate90 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_rotate90.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py index 55036d38cb..dededbba8d 100644 --- a/tests/test_rand_rotate90.py +++ b/tests/test_rand_rotate90.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from monai.transforms import RandRotate90 from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D @@ -26,8 +27,8 @@ def test_default(self): expected = [] for channel in 
self.imt[0]: expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_k(self): rotate = RandRotate90(max_k=2) @@ -37,8 +38,8 @@ def test_k(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_spatial_axes(self): rotate = RandRotate90(spatial_axes=(0, 1)) @@ -48,8 +49,8 @@ def test_spatial_axes(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 0, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) def test_prob_k_spatial_axes(self): rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1)) @@ -59,8 +60,8 @@ def test_prob_k_spatial_axes(self): expected = [] for channel in self.imt[0]: expected.append(np.rot90(channel, 1, (0, 1))) - expected = np.stack(expected) - self.assertTrue(np.allclose(rotated, expected)) + expected = p(np.stack(expected)) + torch.testing.assert_allclose(rotated, expected, rtol=1.0e-5, atol=1.0e-8) if __name__ == "__main__": From c1fe71b2fadd43adf0020ae75670e87c1044fc32 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 10:41:32 +0100 Subject: [PATCH 078/176] fix test_crop_foreground Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_crop_foreground.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 4c51a274a0..7235527e3b 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground @@ -24,7 +25,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), ] ) @@ -32,7 +33,7 @@ [ {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[3]]]), + p([[[3]]]), ] ) @@ -40,7 +41,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), ] ) @@ -48,7 +49,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), ] ) @@ -56,7 +57,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 
0, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), ] ) @@ -64,7 +65,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), ] ) @@ -72,7 +73,7 @@ [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore - np.zeros((1, 0, 0)), + p(np.zeros((1, 0, 0), dtype=np.int64)), ] ) @@ -81,7 +82,7 @@ class TestCropForeground(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) @parameterized.expand([TESTS[0]]) def test_return_coords(self, argments, image, _): From 61d7510ad63e94a67a65e3c59d27805287752765 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 12:22:06 +0100 Subject: [PATCH 079/176] fix gibbs noise Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- tests/test_gibbs_noise.py | 30 +++++++++-------- tests/test_gibbs_noised.py | 38 +++++++++++---------- tests/test_rand_gibbs_noise.py | 40 ++++++++++++---------- tests/test_rand_gibbs_noised.py | 51 +++++++++++++++-------------- 5 files changed, 86 insertions(+), 75 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 7e9eae5b78..9908324573 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1323,7 +1323,7 @@ def _apply_mask(self, k: DataObjects.Images, data_type: type) -> DataObjects.Ima mask = np.repeat(mask[None], k.shape[0], axis=0) if data_type is torch.Tensor: - mask = torch.Tensor(mask) + mask = torch.Tensor(mask).to(k.device) # type: ignore # apply binary mask out: DataObjects.Images = k * mask # type: ignore diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index c174e96657..2d587f2c56 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -19,11 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoise from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +has_torch_fft, _ = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) class TestGibbsNoise(unittest.TestCase): @@ -35,36 +39,36 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, 4, 20, 0, 5)[0][None] - return torch.Tensor(im) if 
as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.8 t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoise(alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 40e430abc6..eeebf1d1c1 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -19,12 +19,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import GibbsNoised from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +has_torch_fft, _ = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, as_tensor_output, input_type)) KEYS = ["im", "label"] @@ -38,49 +42,47 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - ims = create_test_image(*im_shape, 4, 20, 0, 5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} + return {k: input_type(create_test_image(*im_shape, 4, 20, 0, 5)[0]) for k in KEYS} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, as_tensor_output, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.8 t = GibbsNoised(KEYS, alpha, as_tensor_output) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - 
np.testing.assert_allclose(data[k], out[k], atol=1e-2) + torch.testing.assert_allclose(data[k], out[k], atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + torch.testing.assert_allclose(0 * data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 0f8dafdeb8..06980afbaa 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -19,11 +19,15 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoise from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +has_torch_fft, _ = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) class TestRandGibbsNoise(unittest.TestCase): @@ -35,50 +39,50 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, 4, 20, 0, 5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return input_type(im) @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] t = RandGibbsNoise(0.0, alpha) out = t(im) - np.testing.assert_allclose(im, out) + torch.testing.assert_allclose(im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] t = RandGibbsNoise(1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) - np.testing.assert_allclose(out1, out2) + torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) self.assertIsInstance(out1, type(im)) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(im, out, atol=1e-2) + 
torch.testing.assert_allclose(im, out, atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoise(1.0, alpha) out = t(deepcopy(im)) - np.testing.assert_allclose(0 * im, out) + torch.testing.assert_allclose(0 * im, out, rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, as_tensor_input): - im = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoise(1.0, alpha) _ = t(deepcopy(im)) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index a37dd7fa93..fa8820ce9a 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -19,12 +19,16 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandGibbsNoised from monai.utils.misc import set_determinism +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS + +has_torch_fft, _ = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, as_tensor_output, input_type)) KEYS = ["im", "label"] @@ -38,24 +42,22 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d - ims = create_test_image(*im_shape, 4, 20, 0, 5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} + return {k: input_type(create_test_image(*im_shape, 4, 20, 0, 5)[0]) for k in KEYS} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_0_prob(self, im_shape, as_tensor_output, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) out = t(data) for k in KEYS: - np.testing.assert_allclose(data[k], out[k]) + torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_same_result(self, im_shape, as_tensor_output, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) t.set_random_state(42) @@ -63,45 +65,44 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): t.set_random_state(42) out2 = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(out1[k], out2[k]) + torch.testing.assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0) self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_identity(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = 
RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(data[k], out[k], atol=1e-2) + torch.testing.assert_allclose(data[k], out[k], atol=1e-2, rtol=1e-7) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha_1(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - np.testing.assert_allclose(0 * data[k], out[k]) + torch.testing.assert_allclose(0 * data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_dict_matches(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = [0.5, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) - np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) + torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + def test_alpha(self, im_shape, _, input_type): + data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) _ = t(deepcopy(data)) - self.assertGreaterEqual(t.sampled_alpha, 0.5) - self.assertLessEqual(t.sampled_alpha, 0.51) + self.assertTrue(0.5 <= t.sampled_alpha <= 0.51) if __name__ == "__main__": From 801ff55964c82de42faf72279366d490001d7b47 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 15:25:49 +0100 Subject: [PATCH 080/176] improve docstring Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d9c790c91e..47212c20e4 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -46,7 +46,19 @@ def convert_data_type( device: Optional[torch.device] = None, dtype: Optional[Union[DtypeLike, torch.dtype]] = None, ) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: - """Convert to torch.Tensor/np.ndarray.""" + """Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. + + Args: + data: data to be converted + output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) + device: if output is `torch.Tensor`, select device (if blank, unchanged) + dtype: dtype of output data. Converted to correct library type (e.g., + `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). + If left blank, it remains unchanged. 
+ + Returns: + modified data, orig_type, orig_device + """ orig_type = type(data) orig_device = data.device if isinstance(data, torch.Tensor) else None From 169f7b851649dee33bf70041ccb155ddf594232f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 7 Jul 2021 15:33:08 +0100 Subject: [PATCH 081/176] rotate as torch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 30 +++--- tests/test_rand_rotate.py | 91 +++++++++-------- tests/test_rand_rotated.py | 164 +++++++++++++++++------------- tests/test_rotate.py | 77 +++++++------- tests/test_rotated.py | 50 +++++---- 5 files changed, 232 insertions(+), 180 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7b26fe99e2..4757206791 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -31,6 +31,7 @@ TorchOrNumpyTransform, TorchTransform, Transform, + convert_data_type, ) from monai.transforms.utils import ( create_control_grid, @@ -409,7 +410,7 @@ def __call__( return np.asarray(resized) -class Rotate(ToDoTransform, ThreadUnsafe): +class Rotate(TorchTransform, ThreadUnsafe): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -450,12 +451,12 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -478,7 +479,10 @@ def __call__( """ _dtype = dtype or self.dtype or img.dtype - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + + im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") @@ -495,6 +499,8 @@ def __call__( output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) shift_1 = create_translate(input_ndim, (-(output_shape - 1) / 2).tolist()) transform = shift @ transform @ shift_1 + transform_t: torch.Tensor + transform_t, *_ = convert_data_type(transform, torch.Tensor, dtype=_dtype, device=img_t.device) # type: ignore xform = AffineTransform( normalized=False, @@ -504,12 +510,10 @@ def __call__( reverse_indexing=True, ) output = xform( - torch.as_tensor(np.ascontiguousarray(img).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape, + img_t.unsqueeze(0), transform_t, spatial_size=output_shape, ) self._rotation_matrix = transform - return np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) + return self.post_convert_data(output.squeeze(0).float(), orig_type, orig_device) def get_rotation_matrix(self) -> Optional[np.ndarray]: """ @@ -681,7 +685,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return rotator(img) -class RandRotate(ToDoTransform, RandomizableTransform): +class RandRotate(TorchTransform, RandomizableTransform): """ Randomly rotate the input arrays. 
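The Rotate hunks above, and the convert_data_type docstring added just before them, all lean on one convert/compute/revert idiom: bring the input to torch (remembering the original container type and device), run the torch-only computation, then hand the result back in the caller's original container. A minimal self-contained sketch of that idiom, assuming only numpy and torch — the helper below is illustrative and deliberately simpler than the convert_data_type utility in this series, not a drop-in replacement for it:

    import numpy as np
    import torch

    def run_as_torch(img, op):
        # remember what the caller gave us
        is_numpy = isinstance(img, np.ndarray)
        orig_device = img.device if isinstance(img, torch.Tensor) else None
        # compute in torch regardless of input type
        img_t = torch.as_tensor(img)
        out_t = op(img_t)
        # revert to the original container (and device, for tensor inputs)
        if is_numpy:
            return out_t.detach().cpu().numpy()
        return out_t.to(orig_device) if orig_device is not None else out_t

    # e.g. an operation implemented only for tensors, applied to a numpy image:
    out = run_as_torch(np.ones((1, 4, 4), dtype=np.float32), lambda t: torch.flip(t, [1]))

The device bookkeeping is the point of several fixes earlier in the series (moving the GaussianFilter and the Gibbs-noise mask onto the input's device): anything created inside a transform has to live on the same device as the input tensor, or the elementwise ops will raise.
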
@@ -750,12 +754,12 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). @@ -780,9 +784,9 @@ def __call__( mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, - dtype=dtype or self.dtype or img.dtype, + dtype=dtype or self.dtype or img.dtype, # type: ignore ) - return np.array(rotator(img)) + return rotator(img) class RandFlip(TorchTransform, RandomizableTransform): diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py index 0ff8508a0f..4817e81735 100644 --- a/tests/test_rand_rotate.py +++ b/tests/test_rand_rotate.py @@ -10,25 +10,60 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotate2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + + +class TestRandRotate2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotate( range_x=degrees, prob=1.0, @@ -38,7 +73,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) _order = 0 if mode == "nearest" else 1 if mode == "border": @@ -52,38 +87,14 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) expected = np.stack(expected).astype(np.float32) + 
rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated[0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotate3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ] - ) - def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotate( range_x=x, range_y=y, @@ -95,8 +106,8 @@ def test_correct_results(self, x, y, z, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn(self.imt[0]) - np.testing.assert_allclose(rotated.shape, expected) + rotated = rotate_fn(im_type(self.imt[0])) + torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py index 47b4b7107e..4c9a27f668 100644 --- a/tests/test_rand_rotated.py +++ b/tests/test_rand_rotated.py @@ -10,26 +10,104 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import RandRotated from monai.utils import GridSampleMode, GridSamplePadMode -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 2, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "nearest", "border", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True)) + TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True)) -class TestRandRotated2D(NumpyImageTestCase2D): - @parameterized.expand( - [ - (np.pi / 2, True, "bilinear", "border", False), - (np.pi / 4, True, "nearest", "border", False), - (np.pi, False, "nearest", "zeros", True), - ((-np.pi / 4, 0), False, "nearest", "zeros", True), - ] + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append( + (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)) ) - def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_corners): + TEST_CASES_3D.append( + ( + p, + np.pi / 2, + -np.pi / 6, + (0.0, np.pi), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + False, + (1, 87, 104, 109), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + "nearest", + "border", + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + np.pi / 4, + (-np.pi / 9, np.pi / 4.5), + (np.pi / 9, np.pi / 6), + False, + GridSampleMode.NEAREST, + GridSamplePadMode.BORDER, + True, + (1, 89, 105, 104), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * 
np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + "nearest", + "zeros", + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append( + ( + p, + 0.0, + (2 * np.pi, 2.06 * np.pi), + (-np.pi / 180, np.pi / 180), + True, + GridSampleMode.NEAREST, + GridSamplePadMode.ZEROS, + True, + (1, 48, 64, 80), + ) + ) + TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90))) + TEST_CASES_3D.append( + (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)) + ) + + +class TestRandRotated2D(NumpyImageTestCase2D): + @parameterized.expand(TEST_CASES_2D) + def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners): rotate_fn = RandRotated( "img", range_x=degrees, @@ -40,7 +118,7 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) _order = 0 if mode == "nearest" else 1 if padding_mode == "border": @@ -53,70 +131,16 @@ def test_correct_results(self, degrees, keep_size, mode, padding_mode, align_cor expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v expected = np.stack(expected).astype(np.float32) good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRandRotated3D(NumpyImageTestCase3D): - @parameterized.expand( - [ - (np.pi / 2, -np.pi / 6, (0.0, np.pi), False, "bilinear", "border", False, (1, 87, 104, 109)), - ( - np.pi / 2, - -np.pi / 6, - (0.0, np.pi), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - False, - (1, 87, 104, 109), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - "nearest", - "border", - True, - (1, 89, 105, 104), - ), - ( - np.pi / 4, - (-np.pi / 9, np.pi / 4.5), - (np.pi / 9, np.pi / 6), - False, - GridSampleMode.NEAREST, - GridSamplePadMode.BORDER, - True, - (1, 89, 105, 104), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - "nearest", - "zeros", - True, - (1, 48, 64, 80), - ), - ( - 0.0, - (2 * np.pi, 2.06 * np.pi), - (-np.pi / 180, np.pi / 180), - True, - GridSampleMode.NEAREST, - GridSamplePadMode.ZEROS, - True, - (1, 48, 64, 80), - ), - ((-np.pi / 4, 0), 0, 0, False, "nearest", "zeros", False, (1, 48, 77, 90)), - ((-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 48, 77, 90)), - ] - ) - def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corners, expected): + @parameterized.expand(TEST_CASES_3D) + def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected): rotate_fn = RandRotated( "img", range_x=x, @@ -129,7 +153,7 @@ def test_correct_shapes(self, x, y, z, keep_size, mode, padding_mode, align_corn align_corners=align_corners, ) rotate_fn.set_random_state(243) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) np.testing.assert_allclose(rotated["img"].shape, expected) diff --git a/tests/test_rotate.py b/tests/test_rotate.py index 
436c952d4b..16a9c6d124 100644 --- a/tests/test_rotate.py +++ b/tests/test_rotate.py @@ -10,42 +10,44 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotate -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D - -TEST_CASES_2D = [ - (np.pi / 6, False, "bilinear", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", True), -] - -TEST_CASES_3D = [ - (-np.pi / 2, True, "nearest", "border", False), - (np.pi / 4, True, "bilinear", "border", False), - (-np.pi / 4.5, True, "nearest", "reflection", False), - (np.pi, False, "nearest", "zeros", False), - (-np.pi / 2, False, "bilinear", "zeros", False), -] - -TEST_CASES_SHAPE_3D = [ - ([-np.pi / 2, 1.0, 2.0], "nearest", "border", False), - ([np.pi / 4, 0, 0], "bilinear", "border", False), - ([-np.pi / 4.5, -20, 20], "nearest", "reflection", False), -] +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D + +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, -np.pi / 2, False, "bilinear", "zeros", True)) + +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 2, True, "nearest", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, -np.pi / 2, False, "bilinear", "zeros", False)) + +TEST_CASES_SHAPE_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], "nearest", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], "bilinear", "border", False)) + TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], "nearest", "reflection", False)) class TestRotate2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -70,15 +72,16 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") class TestRotate3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, 
align_corners): rotate_fn = Rotate([angle, 0, 0], keep_size, mode, padding_mode, align_corners) - rotated = rotate_fn(self.imt[0]) + rotated = rotate_fn(im_type(self.imt[0])) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated.shape) _order = 0 if mode == "nearest" else 1 @@ -103,23 +106,25 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne ) ) expected = np.stack(expected).astype(np.float32) + rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated n_good = np.sum(np.isclose(expected, rotated, atol=1e-3)) self.assertLessEqual(expected.size - n_good, 5, "diff at most 5 pixels") @parameterized.expand(TEST_CASES_SHAPE_3D) - def test_correct_shape(self, angle, mode, padding_mode, align_corners): + def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners): rotate_fn = Rotate(angle, True, align_corners=align_corners) - rotated = rotate_fn(self.imt[0], mode=mode, padding_mode=padding_mode) + rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode) np.testing.assert_allclose(self.imt[0].shape, rotated.shape) def test_ill_case(self): - rotate_fn = Rotate(10, True) - with self.assertRaises(ValueError): # wrong shape - rotate_fn(self.imt) - - rotate_fn = Rotate(10, keep_size=False) - with self.assertRaises(ValueError): # wrong mode - rotate_fn(self.imt[0], mode="trilinear") + for p in TEST_NDARRAYS: + rotate_fn = Rotate(10, True) + with self.assertRaises(ValueError): # wrong shape + rotate_fn(p(self.imt)) + + rotate_fn = Rotate(10, keep_size=False) + with self.assertRaises(ValueError): # wrong mode + rotate_fn(p(self.imt[0]), mode="trilinear") if __name__ == "__main__": diff --git a/tests/test_rotated.py b/tests/test_rotated.py index 2ea421101b..cd27dd5406 100644 --- a/tests/test_rotated.py +++ b/tests/test_rotated.py @@ -10,36 +10,38 @@ # limitations under the License. 
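# A minimal sketch of the TEST_NDARRAYS parameterization that the rewritten
# tests in this patch rely on (it mirrors the tests/utils.py definition that
# appears later in this series; shown here for orientation only, not as part
# of the diff):
#
#     from functools import partial
#     import numpy as np
#     import torch
#
#     TEST_NDARRAYS = (np.array, torch.as_tensor)
#     if torch.cuda.is_available():
#         TEST_NDARRAYS = TEST_NDARRAYS + (partial(torch.as_tensor, device="cuda"),)
#
# Each parameterized case carries one of these constructors as `im_type`, so
# every transform is exercised with numpy, CPU torch and, when available,
# CUDA torch inputs.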
import unittest +from typing import List, Tuple import numpy as np import scipy.ndimage +import torch from parameterized import parameterized from monai.transforms import Rotated -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -TEST_CASES_2D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_2D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_2D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_2D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_2D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_2D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_2D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) -TEST_CASES_3D = [ - (-np.pi / 6, False, "bilinear", "border", False), - (-np.pi / 4, True, "bilinear", "border", False), - (np.pi / 4.5, True, "nearest", "reflection", False), - (-np.pi, False, "nearest", "zeros", False), - (np.pi / 2, False, "bilinear", "zeros", True), -] +TEST_CASES_3D: List[Tuple] = [] +for p in TEST_NDARRAYS: + TEST_CASES_3D.append((p, -np.pi / 6, False, "bilinear", "border", False)) + TEST_CASES_3D.append((p, -np.pi / 4, True, "bilinear", "border", False)) + TEST_CASES_3D.append((p, np.pi / 4.5, True, "nearest", "reflection", False)) + TEST_CASES_3D.append((p, -np.pi, False, "nearest", "zeros", False)) + TEST_CASES_3D.append((p, np.pi / 2, False, "bilinear", "zeros", True)) class TestRotated2D(NumpyImageTestCase2D): @parameterized.expand(TEST_CASES_2D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -52,6 +54,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels") @@ -64,9 +68,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3D(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, angle, 0], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) 
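        # note: the reference result below comes from scipy.ndimage.rotate, so the
        # grid-sample arguments are mapped onto scipy's: "nearest" becomes spline
        # order 0 (otherwise order 1), and the padding mode becomes the matching
        # scipy boundary mode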
_order = 0 if mode == "nearest" else 1 @@ -79,6 +83,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected.astype(np.float32), rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels.") @@ -91,9 +97,9 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne class TestRotated3DXY(NumpyImageTestCase3D): @parameterized.expand(TEST_CASES_3D) - def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corners): + def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners): rotate_fn = Rotated(("img", "seg"), [0, 0, angle], keep_size, (mode, "nearest"), padding_mode, align_corners) - rotated = rotate_fn({"img": self.imt[0], "seg": self.segn[0]}) + rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])}) if keep_size: np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape) _order = 0 if mode == "nearest" else 1 @@ -106,6 +112,8 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne expected = scipy.ndimage.rotate( self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False ) + for k, v in rotated.items(): + rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3)) self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 voxels") From c9dd886a4a87647b7e61926d91666d09e7ceb654 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 12:42:50 +0100 Subject: [PATCH 082/176] post merge fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 4 +++- monai/transforms/utility/array.py | 18 +++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 4757206791..ee6a77d699 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -510,7 +510,9 @@ def __call__( reverse_indexing=True, ) output = xform( - img_t.unsqueeze(0), transform_t, spatial_size=output_shape, + img_t.unsqueeze(0), + transform_t, + spatial_size=output_shape, ) self._rotation_matrix = transform return self.post_convert_data(output.squeeze(0).float(), orig_type, orig_device) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index fb4cb41469..c0b721073f 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,16 +22,6 @@ import numpy as np import torch -from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, Transform -from monai.transforms.utils import ( - convert_to_numpy, - convert_to_tensor, - extreme_points_to_image, - get_extreme_points, - map_binary_to_indices, -) -from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import from monai.config import DtypeLike from monai.transforms.transform import ( NumpyTransform, @@ -41,7 +31,13 @@ Transform, convert_data_type, ) -from monai.transforms.utils import extreme_points_to_image, get_extreme_points, 
map_binary_to_indices +from monai.transforms.utils import ( + convert_to_numpy, + convert_to_tensor, + extreme_points_to_image, + get_extreme_points, + map_binary_to_indices, +) from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects from monai.utils.misc import dtype_convert From f0e396c821a46d0dbe6269f086e96db78bcf0c0e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 15:37:28 +0100 Subject: [PATCH 083/176] fix test_to_numpy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index c0b721073f..2584682f69 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -376,14 +376,12 @@ def __call__(self, img: Union[DataObjects.Images, Sequence]) -> np.ndarray: """ Apply the transform to `img` and make it contiguous. """ - if isinstance(img, Sequence): - img = np.array(img) + array: np.ndarray + if not isinstance(img, Sequence): + array, *_ = convert_data_type(img, np.ndarray) else: - img, *_ = convert_data_type(img, np.ndarray) - img = np.ascontiguousarray(img) - return img + array = np.asarray(img) - array: np.ndarray = np.asarray(img) return np.ascontiguousarray(array) if array.ndim > 0 else array From 82e944ddf6c4e19ff647832bb0b1e36f85888145 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 15:56:35 +0100 Subject: [PATCH 084/176] skip torch tests if no torch.fft Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_gibbs_noise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 06980afbaa..fef6f29ce3 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -has_torch_fft, _ = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): From 279579d91cb4ff088d7bb7398337f8fcaeba5afc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 15:57:40 +0100 Subject: [PATCH 085/176] same for noised Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_rand_gibbs_noised.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index fa8820ce9a..723256627a 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -has_torch_fft, _ = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): From 595087ed81aa1332ef948bb626228cf21e703f56 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 16:52:58 +0100 Subject: [PATCH 086/176] fix test_as_channel_firstd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 11 +++++++---- tests/test_as_channel_firstd.py | 13 +++++++------ 2 files changed, 14 
insertions(+), 10 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2584682f69..652d4aa9df 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -90,7 +90,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class AsChannelFirst(TorchTransform): +class AsChannelFirst(TorchOrNumpyTransform): """ Change the channel dimension of the image to the first dimension. @@ -115,9 +115,12 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type, orig_device = self.pre_conv_data(img) - img = torch.moveaxis(img, self.channel_dim, 0) # type: ignore - return self.post_convert_data(img, orig_type, orig_device) + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, 0) + + img, orig_type, _ = convert_data_type(img, np.ndarray) + img = np.moveaxis(img, self.channel_dim, 0) + return convert_data_type(img, orig_type)[0] class AsChannelLast(TorchTransform): diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index e70c2e1b47..c15fae198b 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -16,15 +17,15 @@ from monai.transforms import AsChannelFirstd -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_shape(self, input_param, expected_shape): test_data = { "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), From a6fe6f1f93e21a896e70845a72bb03c8e7e96e9b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 16:54:05 +0100 Subject: [PATCH 087/176] fix test_gibbs_noise(d) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_gibbs_noise.py | 2 +- tests/test_gibbs_noised.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 2d587f2c56..68a0a02d6d 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -has_torch_fft, _ = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index eeebf1d1c1..d8b36b30f7 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -has_torch_fft, _ = optional_import("torch.fft") +_, has_torch_fft = 
optional_import("torch.fft") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): From 7992f2913fcf67d67fcbfab30596fde7d371f752 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 17:01:55 +0100 Subject: [PATCH 088/176] channel first and last Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 15 +++++++++------ tests/test_as_channel_first.py | 17 +++++++++-------- tests/test_as_channel_firstd.py | 14 +++++++------- tests/test_as_channel_last.py | 17 +++++++++-------- tests/test_as_channel_lastd.py | 21 +++++++++++---------- 5 files changed, 45 insertions(+), 39 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 652d4aa9df..7480c7ff8c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -118,12 +118,12 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): return torch.moveaxis(img, self.channel_dim, 0) - img, orig_type, _ = convert_data_type(img, np.ndarray) + img, orig_type, orig_device = convert_data_type(img, np.ndarray) img = np.moveaxis(img, self.channel_dim, 0) - return convert_data_type(img, orig_type)[0] + return convert_data_type(img, orig_type, orig_device)[0] -class AsChannelLast(TorchTransform): +class AsChannelLast(TorchOrNumpyTransform): """ Change the channel dimension of the image to the last dimension. @@ -147,9 +147,12 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type, orig_device = self.pre_conv_data(img) - img = torch.moveaxis(img, self.channel_dim, -1) # type: ignore - return self.post_convert_data(img, orig_type, orig_device) + if isinstance(img, torch.Tensor) and hasattr(torch, "moveaxis"): + return torch.moveaxis(img, self.channel_dim, -1) + + img, orig_type, orig_device = convert_data_type(img, np.ndarray) + img = np.moveaxis(img, self.channel_dim, -1) + return convert_data_type(img, orig_type, orig_device)[0] class AddChannel(TorchOrNumpyTransform): diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index e7d9866ae1..1b2b5b87b8 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
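# A hedged standalone example of the transform these cases exercise (shapes
# taken from the cases themselves; a sketch, not part of the diff):
#
#     import numpy as np
#     from monai.transforms import AsChannelFirst
#
#     x = np.zeros((1, 2, 3, 4))            # channel-last input, channel at dim -1
#     y = AsChannelFirst(channel_dim=-1)(x)
#     assert y.shape == (4, 1, 2, 3)        # channel moved to the front
#
# With this patch the same call also accepts a torch.Tensor and returns a
# torch.Tensor, using torch.moveaxis when the installed torch provides it.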
+from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -16,17 +17,17 @@ from monai.transforms import AsChannelFirst -TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)] - -TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)] - -TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirst(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index c15fae198b..f491ca6772 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -19,18 +19,18 @@ TESTS = [] for p in TEST_NDARRAYS: - TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) - TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) - TESTS.append([{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) class TestAsChannelFirstd(unittest.TestCase): @parameterized.expand(TESTS) - def test_shape(self, input_param, expected_shape): + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelFirstd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index 6ec6c8d6e6..dc601251da 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
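# The mirror-image sketch for AsChannelLast (again illustrative, not diff
# content): AsChannelLast(channel_dim=0) applied to an array of shape
# (1, 2, 3, 4) yields shape (2, 3, 4, 1), matching the first case below.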
+from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -16,17 +17,17 @@ from monai.transforms import AsChannelLast -TEST_CASE_1 = [{"channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLast(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): - test_data = np.random.randint(0, 2, size=[1, 2, 3, 4]) + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): + test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelLast(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 2ef4dd4da1..8e1ab5cc88 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -16,20 +17,20 @@ from monai.transforms import AsChannelLastd -TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)] - -TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)] - -TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) class TestAsChannelLastd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_param, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, in_type, input_param, expected_shape): test_data = { - "image": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "label": np.random.randint(0, 2, size=[1, 2, 3, 4]), - "extra": np.random.randint(0, 2, size=[1, 2, 3, 4]), + "image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), + "extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])), } result = AsChannelLastd(**input_param)(test_data) self.assertTupleEqual(result["image"].shape, expected_shape) From 2cf20dbfd089e48269100db349a54e09a9de9073 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 12 Jul 2021 17:22:33 +0100 Subject: [PATCH 089/176] isort/mypy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 6 +++--- tests/test_as_channel_first.py | 2 +- tests/test_as_channel_firstd.py | 2 +- tests/test_as_channel_last.py | 2 +- tests/test_as_channel_lastd.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 7480c7ff8c..ba00cdde32 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -119,7 +119,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return 
torch.moveaxis(img, self.channel_dim, 0) img, orig_type, orig_device = convert_data_type(img, np.ndarray) - img = np.moveaxis(img, self.channel_dim, 0) + img = np.moveaxis(img, self.channel_dim, 0) # type: ignore return convert_data_type(img, orig_type, orig_device)[0] @@ -151,7 +151,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return torch.moveaxis(img, self.channel_dim, -1) img, orig_type, orig_device = convert_data_type(img, np.ndarray) - img = np.moveaxis(img, self.channel_dim, -1) + img = np.moveaxis(img, self.channel_dim, -1) # type: ignore return convert_data_type(img, orig_type, orig_device)[0] @@ -384,7 +384,7 @@ def __call__(self, img: Union[DataObjects.Images, Sequence]) -> np.ndarray: """ array: np.ndarray if not isinstance(img, Sequence): - array, *_ = convert_data_type(img, np.ndarray) + array, *_ = convert_data_type(img, np.ndarray) # type: ignore else: array = np.asarray(img) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 1b2b5b87b8..bc9158f277 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.utils import TEST_NDARRAYS import unittest import numpy as np from parameterized import parameterized from monai.transforms import AsChannelFirst +from tests.utils import TEST_NDARRAYS TESTS = [] for p in TEST_NDARRAYS: diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index f491ca6772..68d33434c1 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.utils import TEST_NDARRAYS import unittest import numpy as np from parameterized import parameterized from monai.transforms import AsChannelFirstd +from tests.utils import TEST_NDARRAYS TESTS = [] for p in TEST_NDARRAYS: diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index dc601251da..55a7a08676 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.utils import TEST_NDARRAYS import unittest import numpy as np from parameterized import parameterized from monai.transforms import AsChannelLast +from tests.utils import TEST_NDARRAYS TESTS = [] for p in TEST_NDARRAYS: diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index 8e1ab5cc88..350f639f3f 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
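# The isort fixes below all apply the same ordering: standard library imports
# first, then third-party (numpy, parameterized), then monai, with the local
# tests.utils import moved last.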
-from tests.utils import TEST_NDARRAYS import unittest import numpy as np from parameterized import parameterized from monai.transforms import AsChannelLastd +from tests.utils import TEST_NDARRAYS TESTS = [] for p in TEST_NDARRAYS: From 07008814f93447eaf71f486d82daa2d15091ce6f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 10:17:36 +0100 Subject: [PATCH 090/176] require fftshift Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_gibbs_noise.py | 2 +- tests/test_gibbs_noised.py | 2 +- tests/test_rand_gibbs_noise.py | 2 +- tests/test_rand_gibbs_noised.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index 68a0a02d6d..ae47b7e872 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft.fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index d8b36b30f7..4ad242051d 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft.fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index fef6f29ce3..943b71ebc0 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft.fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 723256627a..bfd0d759c3 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft") +_, has_torch_fft = optional_import("torch.fft.fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): From fc4cc7f91465e509c40daf70b620abf546e66ea2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 11:30:37 +0100 Subject: [PATCH 091/176] num_workers=0 on windows Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_decollate.py | 2 +- tests/test_handler_transform_inverter.py | 2 +- tests/test_inverse_collation.py | 2 +- tests/test_invertd.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index f958a9ed04..235862ff96 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -121,7 +121,7 @@ def check_match(self, in1, in2): def check_decollate(self, dataset): batch_size = 2 # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 
num_workers=num_workers) diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index f2e75a7153..9efcab68ba 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -70,7 +70,7 @@ def test_invert(self): data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index c302e04017..816f7eb1b1 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -117,7 +117,7 @@ def test_collation(self, _, transform, collate_fn, ndim): modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)]) # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=modified_transform, progress=False) loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 6ba98ee919..820ef2e087 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -70,7 +70,7 @@ def test_invert(self): data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] # num workers = 0 for mac or gpu transforms - num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=transform, progress=False) loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) @@ -84,7 +84,7 @@ def test_invert(self): nearest_interp=True, to_tensor=[True, False, False], device="cpu", - num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, + num_workers=num_workers, ) # execute 1 epoch From 4d7f1e6e6d4a2eb4341e97394ffaf5910940d778 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 13:06:11 +0100 Subject: [PATCH 092/176] improve testcenterspatialcrop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_center_spatial_crop.py | 30 +++++++++++++----------------- tests/utils.py | 4 ++-- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index 3e828176a5..dced356785 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -16,34 +16,30 @@ from parameterized import parameterized from monai.transforms import CenterSpatialCrop +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"roi_size": [2, 2, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 3)] +TEST_SHAPES, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TEST_SHAPES.append([{"roi_size": [2, 2, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 3)]) -TEST_CASE_1 = [{"roi_size": [2, 2, 2]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] - -TEST_CASE_2 = [ - {"roi_size": [2, 2]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 
2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), -] - -TEST_CASE_3 = [ - {"roi_size": [2, 2, 2]}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_SHAPES.append([{"roi_size": [2, 2, 2]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) + TEST_VALUES.append([ + {"roi_size": [2, 2]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + p(np.array([[[1, 2], [2, 3]]])), + ]) class TestCenterSpatialCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TEST_SHAPES) def test_shape(self, input_param, input_data, expected_shape): result = CenterSpatialCrop(**input_param)(input_data) np.testing.assert_allclose(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): result = CenterSpatialCrop(**input_param)(input_data) - np.testing.assert_allclose(result, expected_value) + torch.testing.assert_allclose(result, expected_value, rtol=1e-7, atol=0) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index 9418feff26..ea43093d3d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -579,9 +579,9 @@ def query_memory(n=2): return ",".join([f"{int(x)}" for x in ids]) -TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.tensor) # type: ignore +TEST_NDARRAYS: Tuple[Callable] = (np.array, torch.as_tensor) # type: ignore if torch.cuda.is_available(): - gpu_tensor: Callable = partial(torch.tensor, device="cuda") + gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore From 0f2a6c16a7dd9f406781f5467059667f1e54a6b9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 14:35:53 +0100 Subject: [PATCH 093/176] pickle generator Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 47212c20e4..631cac0fbd 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -193,7 +193,7 @@ class Randomizable(ABC, ThreadUnsafe): """ R: np.random.RandomState = np.random.RandomState() - R_torch = torch.Generator() + R_torch: Optional[torch._C.Generator] = None def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -218,7 +218,7 @@ def set_random_state( _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed _seed = _seed % MAX_SEED self.R = np.random.RandomState(_seed) - self.R_torch = self.R_torch.manual_seed(int(_seed)) + self.R_torch = torch.manual_seed(int(_seed)) return self if state is not None: @@ -228,7 +228,6 @@ def set_random_state( return self self.R = np.random.RandomState() - self.R_torch = torch.Generator() return self def randomize(self, data: Any) -> None: From a186710b2bc186f00444bfda7d611b215dfa2b25 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 14:36:03 +0100 Subject: [PATCH 094/176] rician gpu test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 4 ++-- tests/test_rand_rician_noise.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 12 
deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 9908324573..ba5d18dfb9 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -146,8 +146,8 @@ def _add_noise(self, img: DataObjects.Images, mean: float, std: float): im_shape = img.shape if isinstance(img, torch.Tensor): _std = float(torch.rand(1, generator=self.R_torch)) * std if self.sample_std else std - self._noise1 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype) - self._noise2 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype) + self._noise1 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype).to(img.device) + self._noise2 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype).to(img.device) return torch.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) else: _std = self.R.uniform(0, std) if self.sample_std else std diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 4779895fdd..5330c72006 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -45,16 +45,17 @@ class TestRandRicianNoiseTorch(TorchImageTestCase2D): @parameterized.expand(TESTS) def test_correct_results(self, _, mean, std): seed = 0 - rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) - rician_fn.set_random_state(seed) - noised = rician_fn(self.imt) - torch.manual_seed(seed) - _std = float(torch.rand(1)) * std - expected = torch.sqrt( - (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 - + torch.normal(mean, _std, size=self.imt.shape) ** 2 - ) - np.testing.assert_allclose(expected, noised, atol=1e-5) + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: + rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) + rician_fn.set_random_state(seed) + noised = rician_fn(self.imt.to(device)) + torch.manual_seed(seed) + _std = float(torch.rand(1)) * std + expected = torch.sqrt( + (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 + + torch.normal(mean, _std, size=self.imt.shape) ** 2 + ).to(device) + torch.testing.assert_allclose(expected, noised, rtol=1e-7, atol=1e-5) if __name__ == "__main__": From 68434383b6cb4e4985bbe0cd6e31a1426d72e044 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 14:36:24 +0100 Subject: [PATCH 095/176] Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/transforms.rst | 24 +++ monai/transforms/__init__.py | 10 + monai/transforms/croppad/array.py | 149 +++++++++++++- monai/transforms/croppad/dictionary.py | 186 +++++++++++++++++- monai/transforms/utility/array.py | 50 +++++ monai/transforms/utility/dictionary.py | 45 +++++ monai/transforms/utils.py | 171 +++++++++++++--- tests/test_classes_to_indices.py | 79 ++++++++ tests/test_classes_to_indicesd.py | 84 ++++++++ ...est_generate_label_classes_crop_centers.py | 58 ++++++ tests/test_inverse.py | 11 ++ tests/test_map_classes_to_indices.py | 101 ++++++++++ tests/test_rand_crop_by_label_classes.py | 93 +++++++++ tests/test_rand_crop_by_label_classesd.py | 77 ++++++++ 14 files changed, 1101 insertions(+), 37 deletions(-) create mode 100644 tests/test_classes_to_indices.py create mode 100644 tests/test_classes_to_indicesd.py create mode 100644 tests/test_generate_label_classes_crop_centers.py create mode 100644 
tests/test_map_classes_to_indices.py create mode 100644 tests/test_rand_crop_by_label_classes.py create mode 100644 tests/test_rand_crop_by_label_classesd.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8a890192c8..636d77d187 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -120,6 +120,12 @@ Crop and Pad :members: :special-members: __call__ +`RandCropByLabelClasses` +"""""""""""""""""""""""" +.. autoclass:: RandCropByLabelClasses + :members: + :special-members: __call__ + `ResizeWithPadOrCrop` """"""""""""""""""""" .. autoclass:: ResizeWithPadOrCrop @@ -604,6 +610,12 @@ Utility :members: :special-members: __call__ +`ClassesToIndices` +"""""""""""""""""" +.. autoclass:: ClassesToIndices + :members: + :special-members: __call__ + `ConvertToMultiChannelBasedOnBratsClasses` """""""""""""""""""""""""""""""""""""""""" .. autoclass:: ConvertToMultiChannelBasedOnBratsClasses @@ -700,6 +712,12 @@ Crop and Pad (Dict) :members: :special-members: __call__ +`RandCropByLabelClassesd` +""""""""""""""""""""""""" +.. autoclass:: RandCropByLabelClassesd + :members: + :special-members: __call__ + `ResizeWithPadOrCropd` """""""""""""""""""""" .. autoclass:: ResizeWithPadOrCropd @@ -1183,6 +1201,12 @@ Utility (Dict) :members: :special-members: __call__ +`ClassesToIndicesd` +""""""""""""""""""" +.. autoclass:: ClassesToIndicesd + :members: + :special-members: __call__ + `ConvertToMultiChannelBasedOnBratsClassesd` """"""""""""""""""""""""""""""""""""""""""" .. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ee4d609187..f25b3f0f0b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -19,6 +19,7 @@ CropForeground, DivisiblePad, Pad, + RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, RandSpatialCrop, @@ -49,6 +50,9 @@ DivisiblePadD, DivisiblePadDict, NumpyPadModeSequence, + RandCropByLabelClassesd, + RandCropByLabelClassesD, + RandCropByLabelClassesDict, RandCropByPosNegLabeld, RandCropByPosNegLabelD, RandCropByPosNegLabelDict, @@ -315,6 +319,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, DataStats, EnsureChannelFirst, @@ -352,6 +357,9 @@ CastToTyped, CastToTypeD, CastToTypeDict, + ClassesToIndicesd, + ClassesToIndicesD, + ClassesToIndicesDict, ConcatItemsd, ConcatItemsD, ConcatItemsDict, @@ -445,6 +453,7 @@ create_shear, create_translate, extreme_points_to_image, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, get_extreme_points, @@ -454,6 +463,7 @@ is_empty, is_positive, map_binary_to_indices, + map_classes_to_indices, map_spatial_axes, rand_choice, rescale_array, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index c518bd29e1..c09f1f3eb6 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -26,10 +26,12 @@ from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, Transform, convert_data_type from monai.transforms.utils import ( compute_divisible_spatial_size, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, is_positive, map_binary_to_indices, + map_classes_to_indices, weighted_patch_samples, ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple @@ -49,6 +51,7 @@ "CropForeground", "RandWeightedCrop", 
"RandCropByPosNegLabel", + "RandCropByLabelClasses", "ResizeWithPadOrCrop", "BoundingRect", ] @@ -842,7 +845,11 @@ def randomize( ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) + if self.fg_indices is not None and self.bg_indices is not None: + fg_indices_ = self.fg_indices + bg_indices_ = self.bg_indices + else: + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) else: fg_indices_ = fg_indices bg_indices_ = bg_indices @@ -878,12 +885,7 @@ def __call__( raise ValueError("label should be provided.") if image is None: image = self.image - if fg_indices is None or bg_indices is None: - if self.fg_indices is not None and self.bg_indices is not None: - fg_indices = self.fg_indices - bg_indices = self.bg_indices - else: - fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) + self.randomize(label, fg_indices, bg_indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -894,6 +896,139 @@ def __call__( return results +class RandCropByLabelClasses(Randomizable, Transform): + """ + Crop random fixed sized regions with the center being a class based on the specified ratios of every class. + The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the + cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`:: + + image = np.array([ + [[0.0, 0.3, 0.4, 0.2, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.4], + [0.0, 0.3, 0.5, 0.2, 0.0], + [0.1, 0.2, 0.1, 0.1, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.0]] + ]) + label = np.array([ + [[0, 0, 0, 0, 0], + [0, 1, 2, 1, 0], + [0, 1, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ]) + cropper = RandCropByLabelClasses( + spatial_size=[3, 3], + ratios=[1, 2, 3, 1], + num_classes=4, + num_samples=2, + ) + label_samples = cropper(img=label, label=label, image=image) + + The 2 randomly cropped samples of `label` can be: + [[0, 1, 2], [[0, 0, 0], + [0, 1, 3], [1, 2, 1], + [0, 0, 0]] [1, 3, 0]] + + If a dimension of the expected spatial size is bigger than the input image size, + will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped + results of several images may not have exactly same shape. + + Args: + spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. + if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. + if its components have non-positive values, the corresponding size of `label` will be used. + for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`, + the spatial size of output data will be [32, 40, 40]. + ratios: specified ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + label: the label image that is used for finding every classes, if None, must set at `self.__call__`. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + num_samples: number of samples (crop regions) to take in each list. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). 
+        image_threshold: if enabled `image`, use ``image > image_threshold`` to
+            determine the valid image content area and select class indices only in this area.
+        indices: if provided, use these pre-computed indices of every class to randomly select crop centers,
+            ignoring the above `image` and `image_threshold`; expected to be 1-dim arrays
+            of spatial indices after flattening. a typical usage is to call the `ClassesToIndices` transform first
+            and cache the results for better performance.
+
+    """
+
+    def __init__(
+        self,
+        spatial_size: Union[Sequence[int], int],
+        ratios: Optional[List[Union[float, int]]] = None,
+        label: Optional[np.ndarray] = None,
+        num_classes: Optional[int] = None,
+        num_samples: int = 1,
+        image: Optional[np.ndarray] = None,
+        image_threshold: float = 0.0,
+        indices: Optional[List[np.ndarray]] = None,
+    ) -> None:
+        self.spatial_size = ensure_tuple(spatial_size)
+        self.ratios = ratios
+        self.label = label
+        self.num_classes = num_classes
+        self.num_samples = num_samples
+        self.image = image
+        self.image_threshold = image_threshold
+        self.centers: Optional[List[List[np.ndarray]]] = None
+        self.indices = indices
+
+    def randomize(
+        self,
+        label: np.ndarray,
+        indices: Optional[List[np.ndarray]] = None,
+        image: Optional[np.ndarray] = None,
+    ) -> None:
+        self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
+        indices_: List[np.ndarray]
+        if indices is None:
+            if self.indices is not None:
+                indices_ = self.indices
+            else:
+                indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
+        else:
+            indices_ = indices
+        self.centers = generate_label_classes_crop_centers(
+            self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
+        )
+
+    def __call__(
+        self,
+        img: np.ndarray,
+        label: Optional[np.ndarray] = None,
+        image: Optional[np.ndarray] = None,
+        indices: Optional[List[np.ndarray]] = None,
+    ) -> List[np.ndarray]:
+        """
+        Args:
+            img: input data to crop samples from based on the ratios of every class, assumes `img` is a
+                channel-first array.
+            label: the label image that is used for finding indices of every class, if None, use `self.label`.
+            image: optional image data to help select valid area, can be same as `img` or another image array.
+                use ``image > image_threshold`` to select the centers only in the valid region. if None, use `self.image`.
+            indices: list of indices for every class in the image, used to randomly select crop centers.
+ + """ + if label is None: + label = self.label + if label is None: + raise ValueError("label should be provided.") + if image is None: + image = self.image + + self.randomize(label, indices, image) + results: List[np.ndarray] = [] + if self.centers is not None: + for center in self.centers: + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + results.append(cropper(img)) + + return results + + class ResizeWithPadOrCrop(Transform): """ Resize an image to a target spatial size by either centrally cropping the image or diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index a05f2377bd..947bc80a27 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -40,9 +40,11 @@ from monai.transforms.transform import MapTransform, Randomizable, convert_data_type from monai.transforms.utils import ( allow_missing_keys_mode, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, is_positive, map_binary_to_indices, + map_classes_to_indices, weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key @@ -1096,9 +1098,188 @@ def get_as_np(x) -> np.ndarray: self.randomize(label, fg_indices, bg_indices, image) if not isinstance(self.spatial_size, tuple): - raise AssertionError + raise ValueError("spatial_size must be a valid tuple.") if self.centers is None: - raise AssertionError + raise ValueError("no available ROI centers to crop.") + + # initialize returned list with shallow copy to preserve key ordering + results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + + for i, center in enumerate(self.centers): + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = deepcopy(data[key]) + for key in self.key_iterator(d): + img = d[key] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + orig_size = img.shape[1:] + results[i][key] = cropper(img) + self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) + # add `patch_index` to the meta data + for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): + meta_key = meta_key or f"{key}_{meta_key_postfix}" + if meta_key not in results[i]: + results[i][meta_key] = {} # type: ignore + results[i][meta_key][Key.PATCH_INDEX] = i + + return results + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + current_size = np.asarray(d[key].shape[1:]) + center = transform[InverseKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + # get required pad to start and end + pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) + pad_to_end = orig_size - current_size - pad_to_start + # interleave mins and maxes + pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) + inverse_transform = BorderPad(pad) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + + +class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): + """ + Dictionary-based version 
of :py:class:`monai.transforms.RandCropByLabelClasses`.
+    Crop random fixed sized regions with the center being a class based on the specified ratios of every class.
+    The label data can be One-Hot format array or Argmax data. And will return a list of dictionaries for all the
+    cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::
+
+        cropper = RandCropByLabelClassesd(
+            keys=["image", "label"],
+            label_key="label",
+            spatial_size=[3, 3],
+            ratios=[1, 2, 3, 1],
+            num_classes=4,
+            num_samples=2,
+        )
+        data = {
+            "image": np.array([
+                [[0.0, 0.3, 0.4, 0.2, 0.0],
+                [0.0, 0.1, 0.2, 0.1, 0.4],
+                [0.0, 0.3, 0.5, 0.2, 0.0],
+                [0.1, 0.2, 0.1, 0.1, 0.0],
+                [0.0, 0.1, 0.2, 0.1, 0.0]]
+            ]),
+            "label": np.array([
+                [[0, 0, 0, 0, 0],
+                [0, 1, 2, 1, 0],
+                [0, 1, 3, 0, 0],
+                [0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0]]
+            ]),
+        }
+        result = cropper(data)
+
+        The 2 randomly cropped samples of `label` can be:
+        [[0, 1, 2],     [[0, 0, 0],
+         [0, 1, 3],      [1, 2, 1],
+         [0, 0, 0]]      [1, 3, 0]]
+
+    If a dimension of the expected spatial size is bigger than the input image size,
+    will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped
+    results of several images may not have exactly same shape.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: :py:class:`monai.transforms.compose.MapTransform`
+        label_key: name of key for label image, this will be used for finding indices of every class.
+        spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
+            if a dimension of ROI size is bigger than image size, will not crop that dimension of the image.
+            if its components have non-positive values, the corresponding size of `label` will be used.
+            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,
+            the spatial size of output data will be [32, 40, 40].
+        ratios: specified ratios of every class in the label to generate crop centers, including background class.
+            if None, every class will have the same ratio to generate crop centers.
+        num_classes: number of classes for argmax label, not necessary for One-Hot label.
+        num_samples: number of samples (crop regions) to take in each list.
+        image_key: if image_key is not None, only return the indices of every class that are within the valid
+            region of the image (``image > image_threshold``).
+        image_threshold: if enabled `image_key`, use ``image > image_threshold`` to
+            determine the valid image content area and select class indices only in this area.
+        indices_key: if provided pre-computed indices of every class, will ignore above `image` and
+            `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array
+            of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first
+            and cache the results for better performance.
+        meta_keys: explicitly indicate the key of the corresponding meta data dictionary.
+            used to add `patch_index` to the meta dict.
+            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
+            the meta data is a dictionary object which contains: filename, original_shape, etc.
+            it can be a sequence of string, map to the `keys`.
+            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
+        meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the meta data according
+            to the key data, default is `meta_dict`, the meta data is a dictionary object.
+ used to add `patch_index` to the meta dict. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + label_key: str, + spatial_size: Union[Sequence[int], int], + ratios: Optional[List[Union[float, int]]] = None, + num_classes: Optional[int] = None, + num_samples: int = 1, + image_key: Optional[str] = None, + image_threshold: float = 0.0, + indices_key: Optional[str] = None, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + self.label_key = label_key + self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size + self.ratios = ratios + self.num_classes = num_classes + self.num_samples = num_samples + self.image_key = image_key + self.image_threshold = image_threshold + self.indices_key = indices_key + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.centers: Optional[List[List[np.ndarray]]] = None + + def randomize( + self, + label: np.ndarray, + indices: Optional[List[np.ndarray]] = None, + image: Optional[np.ndarray] = None, + ) -> None: + self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) + indices_: List[np.ndarray] + if indices is None: + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + else: + indices_ = indices + self.centers = generate_label_classes_crop_centers( + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + ) + + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: + d = dict(data) + label = d[self.label_key] + image = d[self.image_key] if self.image_key else None + indices = d.get(self.indices_key) if self.indices_key is not None else None + + self.randomize(label, indices, image) + if not isinstance(self.spatial_size, tuple): + raise ValueError("spatial_size must be a valid tuple.") + if self.centers is None: + raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering results: List[Dict[Hashable, Any]] = [dict(data) for _ in range(self.num_samples)] @@ -1275,5 +1456,6 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: CropForegroundD = CropForegroundDict = CropForegroundd RandWeightedCropD = RandWeightedCropDict = RandWeightedCropd RandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld +RandCropByLabelClassesD = RandCropByLabelClassesDict = RandCropByLabelClassesd ResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd BoundingRectD = BoundingRectDict = BoundingRectd diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ba00cdde32..78f0f142ce 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -37,6 +37,7 @@ extreme_points_to_image, get_extreme_points, map_binary_to_indices, + map_classes_to_indices, ) from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects @@ -68,6 +69,7 @@ "Lambda", "LabelToMask", "FgBgToIndices", + "ClassesToIndices", "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", @@ 
-739,6 +741,54 @@ def __call__( return fg_indices, bg_indices +class ClassesToIndices(Transform): + def __init__( + self, + num_classes: Optional[int] = None, + image_threshold: float = 0.0, + output_shape: Optional[Sequence[int]] = None, + ) -> None: + """ + Compute indices of every class of the input label data, return a list of indices. + If no output_shape specified, output data will be 1 dim indices after flattening. + This transform can help pre-compute indices of the class regions for other transforms. + A typical usage is to randomly select indices of classes to crop. + The main logic is based on :py:class:`monai.transforms.utils.map_classes_to_indices`. + + Args: + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to + determine the valid image content area and select only the indices of classes in this area. + output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + + """ + self.num_classes = num_classes + self.image_threshold = image_threshold + self.output_shape = output_shape + + def __call__( + self, + label: np.ndarray, + image: Optional[np.ndarray] = None, + output_shape: Optional[Sequence[int]] = None, + ) -> List[np.ndarray]: + """ + Args: + label: input data to compute the indices of every class. + image: if image is not None, use ``image > image_threshold`` to define valid region, and only select + the indices within the valid region. + output_shape: expected shape of output indices. if None, use `self.output_shape` instead. + + """ + if output_shape is None: + output_shape = self.output_shape + indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + if output_shape is not None: + indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + + return indices + + class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ Convert labels to multi channels based on brats18 classes: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index d7d00a9853..494392165d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -31,6 +31,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, DataStats, EnsureChannelFirst, @@ -971,6 +972,49 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: return d +class ClassesToIndicesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ClassesToIndices`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + indices_postfix: postfix to save the computed indices of all classes in dict. + for example, if computed on `label` and `postfix = "_cls_indices"`, the key will be `label_cls_indices`. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image_key: if image_key is not None, use ``image > image_threshold`` to define valid region, and only select + the indices within the valid region. + image_threshold: if enabled image_key, use ``image > image_threshold`` to determine the valid image content + area and select only the indices of classes in this area. + output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + allow_missing_keys: don't raise exception if key is missing. 
+
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        indices_postfix: str = "_cls_indices",
+        num_classes: Optional[int] = None,
+        image_key: Optional[str] = None,
+        image_threshold: float = 0.0,
+        output_shape: Optional[Sequence[int]] = None,
+        allow_missing_keys: bool = False,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.indices_postfix = indices_postfix
+        self.image_key = image_key
+        self.converter = ClassesToIndices(num_classes, image_threshold, output_shape)
+
+    def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, np.ndarray]:
+        d = dict(data)
+        image = d[self.image_key] if self.image_key else None
+        for key in self.key_iterator(d):
+            d[str(key) + self.indices_postfix] = self.converter(d[key], image)
+
+        return d
+
+
 class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
@@ -1197,6 +1241,7 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict:
 LambdaD = LambdaDict = Lambdad
 LabelToMaskD = LabelToMaskDict = LabelToMaskd
 FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd
+ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd
 ConvertToMultiChannelBasedOnBratsClassesD = (
     ConvertToMultiChannelBasedOnBratsClassesDict
 ) = ConvertToMultiChannelBasedOnBratsClassesd
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index a840b62bd3..28b0d365a8 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -55,8 +55,10 @@
     "compute_divisible_spatial_size",
     "resize_center",
     "map_binary_to_indices",
+    "map_classes_to_indices",
     "weighted_patch_samples",
     "generate_pos_neg_label_crop_centers",
+    "generate_label_classes_crop_centers",
     "create_grid",
     "create_control_grid",
     "create_rotate",
@@ -276,9 +278,57 @@ def map_binary_to_indices(
         bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0]
     else:
         bg_indices = np.nonzero(~label_flat)[0]
+
     return fg_indices, bg_indices
 
 
+def map_classes_to_indices(
+    label: np.ndarray,
+    num_classes: Optional[int] = None,
+    image: Optional[np.ndarray] = None,
+    image_threshold: float = 0.0,
+) -> List[np.ndarray]:
+    """
+    Filter out indices of every class of the input label data, return the indices after flattening.
+    It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for
+    Argmax label.
+
+    For example:
+    ``label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])`` and `num_classes=3`, will return a list
+    which contains the indices of the 3 classes:
+    ``[np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])]``
+
+    Args:
+        label: use the label data to get the indices of every class.
+        num_classes: number of classes for argmax label, not necessary for One-Hot label.
+        image: if image is not None, only return the indices of every class that are within the valid
+            region of the image (``image > image_threshold``).
+        image_threshold: if enabled `image`, use ``image > image_threshold`` to
+            determine the valid image content area and select class indices only in this area.
+
+    """
+    img_flat: Optional[np.ndarray] = None
+    if image is not None:
+        img_flat = np.any(image > image_threshold, axis=0).ravel()
+
+    indices: List[np.ndarray] = []
+    # assuming the first dimension is channel
+    channels = len(label)
+
+    num_classes_: int = channels
+    if channels == 1:
+        if num_classes is None:
+            raise ValueError("if not One-Hot format label, must provide the num_classes.")
+        num_classes_ = num_classes
+
+    for c in range(num_classes_):
+        label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel()
+        label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat
+        indices.append(np.nonzero(label_flat)[0])
+
+    return indices
+
+
 def weighted_patch_samples(
     spatial_size: Union[int, Sequence[int]],
     w: np.ndarray,
@@ -323,6 +373,44 @@ def weighted_patch_samples(
     return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)]
 
 
+def correct_crop_centers(
+    centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int]
+) -> List[np.ndarray]:
+    """
+    Utility to correct the crop center if the crop size is bigger than the image size.
+
+    Args:
+        centers: pre-computed crop centers, will correct based on the valid region.
+        spatial_size: spatial size of the ROIs to be sampled.
+        label_spatial_shape: spatial shape of the original label data to compare with ROI.
+
+    """
+    spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape)
+    if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all():
+        raise ValueError("The size of the proposed random crop ROI is larger than the image size.")
+
+    # Select subregion to assure valid roi
+    valid_start = np.floor_divide(spatial_size, 2)
+    # add 1 for random
+    valid_end = np.subtract(label_spatial_shape + np.array(1), spatial_size / np.array(2)).astype(np.uint16)
+    # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range
+    # from being too high
+    for i, valid_s in enumerate(valid_start):
+        # need this because np.random.randint does not work with same start and end
+        if valid_s == valid_end[i]:
+            valid_end[i] += 1
+
+    for i, c in enumerate(centers):
+        center_i = c
+        if c < valid_start[i]:
+            center_i = valid_start[i]
+        if c >= valid_end[i]:
+            center_i = valid_end[i] - 1
+        centers[i] = center_i
+
+    return centers
+
+
 def generate_pos_neg_label_crop_centers(
     spatial_size: Union[Sequence[int], int],
     num_samples: int,
@@ -352,33 +440,6 @@ def generate_pos_neg_label_crop_centers(
     """
     if rand_state is None:
         rand_state = np.random.random.__self__  # type: ignore
-    spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape)
-    if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all():
-        raise ValueError("The size of the proposed random crop ROI is larger than the image size.")
-
-    # Select subregion to assure valid roi
-    valid_start = np.floor_divide(spatial_size, 2)
-    # add 1 for random
-    valid_end = np.subtract(label_spatial_shape + np.array(1), spatial_size / np.array(2)).astype(np.uint16)
-    # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range
-    # from being too high
-    for i, valid_s in enumerate(
-        valid_start
-    ):  # need this because np.random.randint does not work with same start and end
-        if valid_s == valid_end[i]:
-            valid_end[i] += 1
-
-    def _correct_centers(
-        center_ori: List[np.ndarray], valid_start: np.ndarray, valid_end: np.ndarray
-    ) -> List[np.ndarray]:
-        for i, c in enumerate(center_ori):
-            center_i = c
-            if c < valid_start[i]:
-                center_i = valid_start[i]
-            if c >= valid_end[i]:
-                center_i = valid_end[i] - 1
-            center_ori[i] = center_i
-        return center_ori
 
     centers = []
     fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices)
@@ -398,7 +459,61 @@ def _correct_centers(
         center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
         # shift center to range of valid centers
         center_ori = list(center)
-        centers.append(_correct_centers(center_ori, valid_start, valid_end))
+        centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
+
+    return centers
+
+
+def generate_label_classes_crop_centers(
+    spatial_size: Union[Sequence[int], int],
+    num_samples: int,
+    label_spatial_shape: Sequence[int],
+    indices: List[np.ndarray],
+    ratios: Optional[List[Union[float, int]]] = None,
+    rand_state: Optional[np.random.RandomState] = None,
+) -> List[List[np.ndarray]]:
+    """
+    Generate valid sample locations based on the specified ratios of label classes.
+    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]
+
+    Args:
+        spatial_size: spatial size of the ROIs to be sampled.
+        num_samples: total sample centers to be generated.
+        label_spatial_shape: spatial shape of the original label data to unravel selected centers.
+        indices: sequence of pre-computed foreground indices of every class in 1 dimension.
+        ratios: ratios of every class in the label to generate crop centers, including background class.
+            if None, every class will have the same ratio to generate crop centers.
+        rand_state: numpy randomState object to align with other modules.
+
+    """
+    if rand_state is None:
+        rand_state = np.random.random.__self__  # type: ignore
+
+    if num_samples < 1:
+        raise ValueError("num_samples must be an int number and greater than 0.")
+    ratios_: List[Union[float, int]] = ([1] * len(indices)) if ratios is None else ratios
+    if len(ratios_) != len(indices):
+        raise ValueError("random crop ratios must match the number of indices of classes.")
+    if any([i < 0 for i in ratios_]):
+        raise ValueError("ratios should not contain negative number.")
+
+    # ensure indices are numpy array
+    indices = [np.asarray(i) for i in indices]
+    for i, array in enumerate(indices):
+        if len(array) == 0:
+            warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.")
+            ratios_[i] = 0
+
+    centers = []
+    classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_))
+    for i in classes:
+        # randomly select the indices of a class based on the ratios
+        indices_to_use = indices[i]
+        random_int = rand_state.randint(len(indices_to_use))
+        center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
+        # shift center to range of valid centers
+        center_ori = list(center)
+        centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
 
     return centers
 
diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py
new file mode 100644
index 0000000000..0ba3dd094a
--- /dev/null
+++ b/tests/test_classes_to_indices.py
@@ -0,0 +1,79 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndices + +TEST_CASE_1 = [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + {"num_classes": 3, "image_threshold": 60}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + {"image_threshold": 0.0}, + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + {"num_classes": None, "image_threshold": 60}, + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], +] + + +class TestClassesToIndices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_value(self, input_args, label, image, expected_indices): + indices = ClassesToIndices(**input_args)(label, image) + for i, e in zip(indices, expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py new file mode 100644 index 0000000000..67fac95c8c --- /dev/null +++ b/tests/test_classes_to_indicesd.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
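For orientation before the cases below: the dictionary wrapper stores the computed per-class indices under a new key built from the transform's postfix. A minimal sketch of the round trip (illustrative data only; the default postfix `_cls_indices` from the diff above is assumed)::

    import numpy as np
    from monai.transforms import ClassesToIndicesd

    data = {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}  # 1 x 3 x 3 argmax label
    out = ClassesToIndicesd(keys="label", num_classes=3)(data)
    # flattened indices of classes 0, 1, 2 land under "label_cls_indices"
    print(out["label_cls_indices"])  # [array([0, 4, 8]), array([1, 5, 6]), array([2, 3, 7])]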
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndicesd + +TEST_CASE_1 = [ + # test Argmax data + {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, + { + "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + {"keys": "label", "image_threshold": 0.0}, + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test output_shape + {"keys": "label", "indices_postfix": "cls", "num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], +] + + +class TestClassesToIndicesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_value(self, input_args, input_data, expected_indices): + result = ClassesToIndicesd(**input_args)(input_data) + key_postfix = input_args.get("indices_postfix") + key_postfix = "_cls_indices" if key_postfix is None else key_postfix + for i, e in zip(result["label" + key_postfix], expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py new file mode 100644 index 0000000000..38f2a3e0d1 --- /dev/null +++ b/tests/test_generate_label_classes_crop_centers.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
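At its core, the center generator tested below turns `ratios` into sampling probabilities, draws one class per sample, draws one flattened index of that class, and unravels it into a spatial center; the clamping performed by `correct_crop_centers` is left out of this condensed restatement (a sketch, not the library code itself)::

    import numpy as np

    rs = np.random.RandomState(0)
    indices = [np.array([3, 12, 21]), np.array([1, 9, 18])]  # flattened indices per class
    ratios = np.asarray([1, 2], dtype=float)
    classes = rs.choice(len(ratios), size=4, p=ratios / ratios.sum())  # one class per sample
    centers = [np.unravel_index(rs.choice(indices[c]), (3, 3, 3)) for c in classes]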
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import generate_label_classes_crop_centers + +TEST_CASE_1 = [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "ratios": [1, 2], + "label_spatial_shape": [3, 3, 3], + "indices": [[3, 12, 21], [1, 9, 18]], + "rand_state": np.random.RandomState(), + }, + list, + 2, + 3, +] + +TEST_CASE_2 = [ + { + "spatial_size": [2, 2, 2], + "num_samples": 1, + "ratios": None, + "label_spatial_shape": [3, 3, 3], + "indices": [[3, 12, 21], [1, 9, 18]], + "rand_state": np.random.RandomState(), + }, + list, + 1, + 3, +] + + +class TestGenerateLabelClassesCropCenters(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): + result = generate_label_classes_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_inverse.py b/tests/test_inverse.py index eb53743d12..bf390ee3a3 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -39,6 +39,7 @@ Orientationd, RandAffined, RandAxisFlipd, + RandCropByLabelClassesd, RandCropByPosNegLabeld, RandFlipd, Randomizable, @@ -446,6 +447,15 @@ ) ) +TESTS.append( + ( + "RandCropByLabelClassesd 2d", + "2D", + 1e-7, + RandCropByLabelClassesd(KEYS, "label", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10), + ) +) + TESTS.append( ( "RandCropByPosNegLabeld 2d", @@ -479,6 +489,7 @@ NUM_SAMPLES = 5 N_SAMPLES_TESTS = [ + [RandCropByLabelClassesd(KEYS, "label", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)], [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)], [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)], [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)], diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py new file mode 100644 index 0000000000..2320954520 --- /dev/null +++ b/tests/test_map_classes_to_indices.py @@ -0,0 +1,101 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
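The cases below exercise both label layouts; the key property is that an argmax label with `num_classes` and its One-Hot expansion produce identical per-class indices. A quick illustrative check::

    import numpy as np
    from monai.transforms import map_classes_to_indices

    argmax = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])              # 1 x 3 x 3
    one_hot = np.stack([argmax[0] == c for c in range(3)]).astype(int)  # 3 x 3 x 3
    for a, b in zip(map_classes_to_indices(argmax, num_classes=3), map_classes_to_indices(one_hot)):
        np.testing.assert_allclose(a, b)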
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import map_classes_to_indices + +TEST_CASE_1 = [ + # test Argmax data + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + { + "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "num_classes": 3, + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "num_classes": None, + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test empty class + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], +] + +TEST_CASE_6 = [ + # test empty class + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], +] + + +class TestMapClassesToIndices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_value(self, input_data, expected_indices): + indices = map_classes_to_indices(**input_data) + for i, e in zip(indices, expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py new file mode 100644 index 0000000000..b21f971042 --- /dev/null +++ b/tests/test_rand_crop_by_label_classes.py @@ -0,0 +1,93 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
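Before the shape-oriented cases below, a small usage sketch of the array transform (hypothetical data; per the sampling logic above, a zero ratio excludes that class from center selection)::

    import numpy as np
    from monai.transforms import RandCropByLabelClasses

    label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])  # 1 x 3 x 3 argmax label
    cropper = RandCropByLabelClasses(spatial_size=[2, 2], ratios=[0, 1, 1], num_classes=3, num_samples=4)
    samples = cropper(img=np.random.rand(1, 3, 3), label=label)
    # a list of four 1 x 2 x 2 crops, each centered on a voxel of class 1 or 2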
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndices, RandCropByLabelClasses + +TEST_CASE_0 = [ + # One-Hot label + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + list, + (3, 2, 2, 3), +] + +TEST_CASE_1 = [ + # Argmax label + { + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + list, + (3, 2, 2, 2), +] + +TEST_CASE_2 = [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + list, + (3, 2, 2, 2), +] + + +class TestRandCropByLabelClasses(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_type_shape(self, input_param, input_data, expected_type, expected_shape): + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + def test_indices(self, input_param, input_data, expected_type, expected_shape): + input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + # test set indices at runtime + input_data["indices"] = input_param["indices"] + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py new file mode 100644 index 0000000000..829096953b --- /dev/null +++ b/tests/test_rand_crop_by_label_classesd.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
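The dictionary version tested below is typically chained after `ClassesToIndicesd`, so the index computation runs once in the deterministic (cacheable) part of a pipeline and `indices_key` points the cropper at the stored result. A hypothetical end-to-end sketch::

    import numpy as np
    from monai.transforms import ClassesToIndicesd, Compose, RandCropByLabelClassesd

    xform = Compose(
        [
            ClassesToIndicesd(keys="label", num_classes=2),
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[2, 2, 2],
                ratios=[1, 1],
                num_classes=2,
                num_samples=2,
                indices_key="label_cls_indices",
            ),
        ]
    )
    data = {
        "image": np.random.rand(1, 3, 3, 3),
        "label": np.random.randint(0, 2, size=[1, 3, 3, 3]),
    }
    samples = xform(data)  # a list of two cropped dictionaries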
+ +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd + +TEST_CASE_0 = [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + list, + (3, 2, 2, 3), +] + +TEST_CASE_1 = [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + }, + list, + (3, 2, 2, 2), +] + + +class TestRandCropByLabelClassesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + def test_type_shape(self, input_param, input_data, expected_type, expected_shape): + result = RandCropByLabelClassesd(**input_param)(input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0]["img"].shape, expected_shape) + # test with pre-computed indices + input_data = ClassesToIndicesd(keys="label", num_classes=input_param["num_classes"])(input_data) + input_param["indices_key"] = "label_cls_indices" + result = RandCropByLabelClassesd(**input_param)(input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0]["img"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 1e8dfaaa5fe14f774401199904af270a7ab77059 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 16:32:56 +0100 Subject: [PATCH 096/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_center_spatial_crop.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_center_spatial_crop.py b/tests/test_center_spatial_crop.py index dced356785..51abe7b07f 100644 --- a/tests/test_center_spatial_crop.py +++ b/tests/test_center_spatial_crop.py @@ -24,11 +24,14 @@ TEST_SHAPES.append([{"roi_size": [2, 2, 2]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) - TEST_VALUES.append([ - {"roi_size": [2, 2]}, - p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), - p(np.array([[[1, 2], [2, 3]]])), - ]) + TEST_VALUES.append( + [ + {"roi_size": [2, 2]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + p(np.array([[[1, 2], [2, 3]]])), + ] + ) + class TestCenterSpatialCrop(unittest.TestCase): @parameterized.expand(TEST_SHAPES) From 3181f4c614b5c04c02bc7eb688e6230753856c42 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 13 Jul 2021 17:14:52 +0100 Subject: [PATCH 097/176] flake8 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 2 +- monai/transforms/croppad/dictionary.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index c09f1f3eb6..f636bdae4e 100644 --- a/monai/transforms/croppad/array.py 
+++ b/monai/transforms/croppad/array.py @@ -1024,7 +1024,7 @@ def __call__( if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) + results.append(cropper(img)) # type: ignore return results diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 947bc80a27..70bb6beab1 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1113,7 +1113,7 @@ def get_as_np(x) -> np.ndarray: img = d[key] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore orig_size = img.shape[1:] - results[i][key] = cropper(img) + results[i][key] = cropper(img) # type: ignore self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): @@ -1140,7 +1140,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) inverse_transform = BorderPad(pad) # Apply inverse transform - d[key] = inverse_transform(d[key]) + d[key] = inverse_transform(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) From 482286596eae66c8d58bd6fdfdd3cb379c8a3631 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 11:15:37 +0100 Subject: [PATCH 098/176] Merge remote-tracking branch 'MONAI/dev' into feature/2231-transforms Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/handlers.rst | 5 + docs/source/transforms.rst | 4 +- monai/_extensions/gmm/gmm_cpu.cpp | 2 +- monai/apps/datasets.py | 13 ++- monai/apps/mmars/mmars.py | 4 +- .../bilateral/bilateralfilter_cpu_phl.cpp | 6 +- monai/data/csv_saver.py | 2 +- monai/data/dataset.py | 81 +++++++++------- monai/data/decathlon_datalist.py | 4 +- monai/data/grid_dataset.py | 2 +- monai/data/image_reader.py | 14 ++- monai/data/iterable_dataset.py | 6 +- monai/data/utils.py | 6 +- monai/engines/utils.py | 4 +- monai/engines/workflow.py | 10 +- monai/handlers/__init__.py | 1 + monai/handlers/checkpoint_loader.py | 2 +- monai/handlers/classification_saver.py | 2 +- monai/handlers/decollate_batch.py | 94 +++++++++++++++++++ monai/handlers/metrics_saver.py | 2 +- monai/handlers/postprocessing.py | 15 ++- monai/handlers/stats_handler.py | 32 ++++--- monai/handlers/transform_inverter.py | 2 +- monai/handlers/utils.py | 27 +++--- monai/losses/deform.py | 4 +- monai/losses/dice.py | 24 ++--- monai/losses/multi_scale.py | 12 ++- monai/losses/spatial_mask.py | 8 +- monai/metrics/meandice.py | 7 +- monai/metrics/metric.py | 6 +- monai/metrics/regression.py | 18 ++-- monai/metrics/utils.py | 2 +- monai/networks/blocks/crf.py | 7 +- monai/networks/blocks/downsample.py | 3 +- monai/networks/blocks/dynunet_block.py | 3 +- monai/networks/blocks/dynunet_block_v1.py | 12 +-- monai/networks/blocks/fcn.py | 18 ++-- monai/networks/nets/resnet.py | 2 +- monai/transforms/compose.py | 2 +- monai/transforms/croppad/array.py | 10 +- monai/transforms/croppad/batch.py | 6 +- monai/transforms/croppad/dictionary.py | 10 +- monai/transforms/intensity/array.py | 6 +- monai/transforms/intensity/dictionary.py | 2 +- monai/transforms/inverse_batch_transform.py | 38 ++++++-- monai/transforms/post/dictionary.py | 4 +- 
monai/transforms/spatial/array.py | 4 +- monai/transforms/spatial/dictionary.py | 2 +- monai/transforms/transform.py | 5 +- monai/transforms/utility/array.py | 2 +- monai/transforms/utils.py | 13 +-- monai/utils/dist.py | 2 +- monai/utils/jupyter_utils.py | 9 +- monai/utils/misc.py | 6 +- monai/utils/state_cacher.py | 5 +- tests/min_tests.py | 1 + tests/test_decollate.py | 36 +++++++ tests/test_handler_decollate_batch.py | 63 +++++++++++++ tests/test_handler_post_processing.py | 3 +- tests/test_lmdbdataset.py | 5 +- tests/test_resize_with_pad_or_crop.py | 2 +- tests/test_resize_with_pad_or_cropd.py | 2 +- 62 files changed, 473 insertions(+), 231 deletions(-) create mode 100644 monai/handlers/decollate_batch.py create mode 100644 tests/test_handler_decollate_batch.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 32da051ab5..096777cdef 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -160,6 +160,11 @@ Post processing .. autoclass:: PostProcessing :members: +Decollate batch +--------------- +.. autoclass:: DecollateBatch + :members: + Utilities --------- .. automodule:: monai.handlers.utils diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 636d77d187..fcd9adba94 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -742,8 +742,8 @@ Crop and Pad (Dict) :members: :special-members: __call__ -Instensity (Dict) -^^^^^^^^^^^^^^^^^ +Intensity (Dict) +^^^^^^^^^^^^^^^^ `RandGaussianNoised` """""""""""""""""""" diff --git a/monai/_extensions/gmm/gmm_cpu.cpp b/monai/_extensions/gmm/gmm_cpu.cpp index 9bb9bd92fc..144e66806c 100644 --- a/monai/_extensions/gmm/gmm_cpu.cpp +++ b/monai/_extensions/gmm/gmm_cpu.cpp @@ -17,7 +17,7 @@ limitations under the License. void learn_cpu(const float* input, const int* labels, float* gmm, float* scratch_memory, unsigned int batch_count, unsigned int element_count) { - throw std::invalid_argument("GMM recieved a cpu tensor but is not yet implemented for the cpu"); + throw std::invalid_argument("GMM received a cpu tensor but is not yet implemented for the cpu"); } void apply_cpu(const float* gmm, const float* input, float* output, unsigned int batch_count, unsigned int element_count) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index c2daec6e12..c766914026 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -129,7 +129,6 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: image_class.extend([i] * num_each[i]) class_name.extend([class_names[i]] * num_each[i]) - data = [] length = len(image_files_list) indices = np.arange(length) self.randomize(indices) @@ -147,10 +146,14 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) - for i in section_indices: - data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}) - - return data + return [ + { + "image": image_files_list[i], + "label": image_class[i], + "class_name": class_name[i], + } + for i in section_indices + ] class DecathlonDataset(Randomizable, CacheDataset): diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index a6a6d6eeae..e7ff28ce44 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -79,7 +79,7 @@ def _get_all_ngc_models(pattern, page_index=0, page_size=50): else: raise ValueError("NGC API requires requests package. 
Please install it.") model_list = json.loads(resp.text) - model_dict = dict() + model_dict = {} for result in model_list["results"]: for model in result["resources"]: current_res_id = model["resourceId"] @@ -136,7 +136,7 @@ def download_mmar(item, mmar_dir=None, progress: bool = True, api: bool = False, model_dict = _get_all_ngc_models(item) if len(model_dict) == 0: raise ValueError(f"api query returns no item for pattern {item}. Please change or shorten it.") - model_dir_list = list() + model_dir_list = [] for k, v in model_dict.items(): ver = v["latest"] if version == -1 else str(version) download_url = _get_ngc_url(k, ver) diff --git a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp index 1fb48cb6c9..847a452396 100644 --- a/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp +++ b/monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp @@ -51,11 +51,11 @@ void BilateralFilterPHLCpu( } // Spatial features - int offsetRemanider = i; + int offsetRemainder = i; for (int d = 0; d < desc.dimensions; d++) { - int coord = offsetRemanider / desc.strides[d]; - offsetRemanider -= coord * desc.strides[d]; + int coord = offsetRemainder / desc.strides[d]; + offsetRemainder -= coord * desc.strides[d]; features[i * featureChannels + desc.channelCount + d] = invSpatialSigma * coord; } diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 59bbddb696..c333f69c4c 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -44,7 +44,7 @@ def __init__( output_dir: output CSV file directory. filename: name of the saved CSV file name. overwrite: whether to overwriting existing CSV file content, if True, will clear the file before saving. - otherwise, will apend new content to the CSV file. + otherwise, will append new content to the CSV file. flush: whether to write the cache data to CSV file immediately when `save_batch` and clear the cache. default to False. diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 4d18bd4e0d..e8ec02e2a8 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -259,9 +259,7 @@ def _cachecheck(self, item_transformed): try: return torch.load(hashfile) except PermissionError as e: - if sys.platform == "win32": - pass # windows machine multiprocessing not efficiently supported - else: + if sys.platform != "win32": raise e _item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed @@ -413,7 +411,11 @@ def __init__( self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size + # lmdb is single-writer multi-reader by default + # the cache is created without multi-threading self._read_env = None + # this runs on the primary thread/process + self._fill_cache_start_reader(show_progress=self.progress) print(f"Accessing lmdb file: {self.db_file.absolute()}.") def set_data(self, data: Sequence): @@ -422,43 +424,53 @@ def set_data(self, data: Sequence): """ super().set_data(data=data) - self._read_env = None + self._read_env = self._fill_cache_start_reader(show_progress=self.progress) + + def _fill_cache_start_reader(self, show_progress=True): + """ + Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write. + This method can be used with multiple processes, but it may have a negative impact on the performance. 
- def _fill_cache_start_reader(self): + Args: + show_progress: whether to show the progress bar if possible. + """ # create cache self.lmdb_kwargs["readonly"] = False env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) - if self.progress and not has_tqdm: + if show_progress and not has_tqdm: warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.") - for item in tqdm(self.data) if has_tqdm and self.progress else self.data: - key = self.hash_func(item) - done, retry, val = False, 5, None - while not done and retry > 0: - try: - with env.begin(write=True) as txn: - with txn.cursor() as cursor: + with env.begin(write=False) as search_txn: + for item in tqdm(self.data) if has_tqdm and show_progress else self.data: + key = self.hash_func(item) + done, retry, val = False, 5, None + while not done and retry > 0: + try: + with search_txn.cursor() as cursor: done = cursor.set_key(key) - if done: - continue + if done: + continue if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed val = pickle.dumps(val, protocol=self.pickle_protocol) - txn.put(key, val) - done = True - except lmdb.MapFullError: - done, retry = False, retry - 1 + with env.begin(write=True) as txn: + txn.put(key, val) + done = True + except lmdb.MapFullError: + done, retry = False, retry - 1 + size = env.info()["map_size"] + new_size = size * 2 + warnings.warn( + f"Resizing the cache database from {int(size) >> 20}MB" f" to {int(new_size) >> 20}MB." + ) + env.set_mapsize(new_size) + except lmdb.MapResizedError: + # the mapsize is increased by another process + # set_mapsize with a size of 0 to adopt the new size + env.set_mapsize(0) + if not done: # still has the map full error size = env.info()["map_size"] - new_size = size * 2 - warnings.warn(f"Resizing the cache database from {int(size) >> 20}MB to {int(new_size) >> 20}MB.") - env.set_mapsize(new_size) - except lmdb.MapResizedError: - # the mapsize is increased by another process - # set_mapsize with a size of 0 to adopt the new size, - env.set_mapsize(0) - if not done: # still has the map full error - size = env.info()["map_size"] - env.close() - raise ValueError(f"LMDB map size reached, increase size above current size of {size}.") + env.close() + raise ValueError(f"LMDB map size reached, increase size above current size of {size}.") size = env.info()["map_size"] env.close() # read-only database env @@ -476,7 +488,8 @@ def _cachecheck(self, item_transformed): """ if self._read_env is None: - self._read_env = self._fill_cache_start_reader() + # this runs on multiple processes, each one should have its own env. + self._read_env = self._fill_cache_start_reader(show_progress=False) with self._read_env.begin(write=False) as txn: data = txn.get(self.hash_func(item_transformed)) if data is None: @@ -582,7 +595,7 @@ def set_data(self, data: Sequence): """ Set the input data and run deterministic transforms to generate cache content. - Note: should call this func after an entire epoch and must set `persisten_workers=False` + Note: should call this func after an entire epoch and must set `persistent_workers=False` in PyTorch DataLoader, because it needs to create new worker processes based on new generated cache content. 
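For context, the retry loop above follows a standard LMDB pattern: the memory map has a fixed size, a write that overflows it raises `MapFullError`, and the map can be grown with `set_mapsize` before retrying. A standalone sketch of that pattern (simplified; the file name and helper are illustrative, assuming the `lmdb` package)::

    import lmdb

    env = lmdb.open("monai_cache.lmdb", subdir=False, map_size=1024 ** 3)

    def put_with_resize(key: bytes, val: bytes, retries: int = 5) -> None:
        for _ in range(retries):
            try:
                with env.begin(write=True) as txn:
                    txn.put(key, val)  # raises once the map is full
                return
            except lmdb.MapFullError:
                # double the map and retry, mirroring the caching loop above
                env.set_mapsize(env.info()["map_size"] * 2)
        raise ValueError("LMDB map size reached.")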
@@ -1130,10 +1143,10 @@ def _transform(self, index: int): class CSVDataset(Dataset): """ Dataset to load data from CSV files and generate a list of dictionaries, - every dictionay maps to a row of the CSV file, and the keys of dictionary + every dictionary maps to a row of the CSV file, and the keys of dictionary map to the column names of the CSV file. - It can load multiple CSV files and join the tables with addtional `kwargs` arg. + It can load multiple CSV files and join the tables with additional `kwargs` arg. Support to only load specific rows and columns. And it can also group several loaded columns to generate a new column, for example, set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be:: diff --git a/monai/data/decathlon_datalist.py b/monai/data/decathlon_datalist.py index 11fb5edd28..663b68a08e 100644 --- a/monai/data/decathlon_datalist.py +++ b/monai/data/decathlon_datalist.py @@ -71,9 +71,7 @@ def _append_paths(base_dir: str, is_segmentation: bool, items: List[Dict]) -> Li if not isinstance(item, dict): raise TypeError(f"Every item in items must be a dict but got {type(item).__name__}.") for k, v in item.items(): - if k == "image": - item[k] = _compute_path(base_dir, v, check_path=False) - elif is_segmentation and k == "label": + if k == "image" or is_segmentation and k == "label": item[k] = _compute_path(base_dir, v, check_path=False) else: # for other items, auto detect whether it's a valid path diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index b789b9c032..c4efe3ad2a 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -146,7 +146,7 @@ def __iter__(self): if worker_info is not None: # split workload per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers))) - iter_start = iter_start + worker_info.id * per_worker + iter_start += worker_info.id * per_worker iter_end = min(iter_start + per_worker, iter_end) for index in range(iter_start, iter_end): diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 3e2110dc77..cfe8f29f04 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -115,14 +115,12 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): def _stack_images(image_list: List, meta_dict: Dict): - if len(image_list) > 1: - if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): - raise RuntimeError("can not read a list of images which already have channel dimension.") - meta_dict["original_channel_dim"] = 0 - img_array = np.stack(image_list, axis=0) - else: - img_array = image_list[0] - return img_array + if len(image_list) <= 1: + return image_list[0] + if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): + raise RuntimeError("can not read a list of images which already have channel dimension.") + meta_dict["original_channel_dim"] = 0 + return np.stack(image_list, axis=0) class ITKReader(ImageReader): diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 75bab462d4..c4fc252586 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -54,12 +54,12 @@ def __iter__(self): class CSVIterableDataset(IterableDataset): """ Iterable dataset to load CSV files and generate dictionary data. - It can be helpful when loading extemely big CSV files that can't read into memory directly. + It can be helpful when loading extremely big CSV files that can't read into memory directly. 
To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, - every process executes tranforms on part of every loaded chunk. + every process executes transforms on part of every loaded chunk. Note: the order of output data may not match data source in multi-processing mode. - It can load data from multiple CSV files and join the tables with addtional `kwargs` arg. + It can load data from multiple CSV files and join the tables with additional `kwargs` arg. Support to only load specific columns. And it can also group several loaded columns to generate a new column, for example, set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be:: diff --git a/monai/data/utils.py b/monai/data/utils.py index ed6a956108..716c344218 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -551,7 +551,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru Args: affine (nxn matrix): a square matrix. scale: new scaling factor along each dimension. if the components of the `scale` are non-positive values, - will use the corresponding components of the origial pixdim, which is computed from the `affine`. + will use the corresponding components of the original pixdim, which is computed from the `affine`. diagonal: whether to return a diagonal scaling matrix. Defaults to True. @@ -1105,11 +1105,11 @@ def convert_tables_to_dicts( if isinstance(col_types, dict): # fill default values for NaN defaults = {k: v["default"] for k, v in col_types.items() if v is not None and v.get("default") is not None} - if len(defaults) > 0: + if defaults: data_ = data_.fillna(value=defaults) # convert data types types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v} - if len(types) > 0: + if types: data_ = data_.astype(dtype=types) data: List[Dict] = data_.to_dict(orient="records") diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 7ed1ee029f..c94cc16916 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -87,7 +87,7 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t if devices is None: devices = [torch.device(f"cuda:{d:d}") for d in range(torch.cuda.device_count())] - if len(devices) == 0: + if not devices: raise RuntimeError("No GPU devices available.") elif len(devices) == 0: @@ -115,7 +115,7 @@ def default_prepare_batch( """ if not isinstance(batchdata, dict): raise AssertionError("default prepare_batch expects dictionary input data.") - if isinstance(batchdata.get(CommonKeys.LABEL, None), torch.Tensor): + if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor): return ( batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking), diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1d76fcaf83..4e1834a625 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -18,9 +18,8 @@ from torch.utils.data.distributed import DistributedSampler from monai.config import IgniteInfo -from monai.data import decollate_batch, rep_scalar_to_batch from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch -from monai.transforms import Transform +from monai.transforms import Decollated, Transform from monai.utils import ensure_tuple, min_version, optional_import from .utils import engine_apply_transform @@ -179,15 +178,16 @@ def set_sampler_epoch(engine: Engine): def _register_decollate(self): """ 
- Register the decollate operation for batch data, will execure after model forward and loss forward. + Register the decollate operation for batch data, will execute after model forward and loss forward. """ @self.on(IterationEvents.MODEL_COMPLETED) def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate - engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) - engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) + transform = Decollated(keys=None, detach=True) + engine.state.batch = transform(engine.state.batch) + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 39d75064c2..42a716ced0 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .decollate_batch import DecollateBatch from .earlystop_handler import EarlyStopHandler from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 2cc2856193..f1f60abf63 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -78,7 +78,7 @@ def __init__( if load_path is None: raise AssertionError("must provide clear path to load checkpoint.") self.load_path = load_path - if not (load_dict is not None and len(load_dict) > 0): + if load_dict is None or len(load_dict) <= 0: raise AssertionError("must provide target objects to load.") self.logger = logging.getLogger(name) self.load_dict = load_dict diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 96df3c1523..815be87754 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -51,7 +51,7 @@ def __init__( output_dir: if `saver=None`, output CSV file directory. filename: if `saver=None`, name of the saved CSV file name. overwrite: if `saver=None`, whether to overwriting existing file content, if True, - will clear the file before saving. otherwise, will apend new content to the file. + will clear the file before saving. otherwise, will append new content to the file. batch_transform: a callable that is used to extract the `meta_data` dictionary of the input images from `ignite.engine.state.batch`. the purpose is to get the input filenames from the `meta_data` and store with classification results together. diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py new file mode 100644 index 0000000000..4e99fc6f04 --- /dev/null +++ b/monai/handlers/decollate_batch.py @@ -0,0 +1,94 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional
+
+from monai.config import IgniteInfo, KeysCollection
+from monai.engines.utils import IterationEvents
+from monai.transforms import Decollated
+from monai.utils import min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class DecollateBatch:
+    """
+    Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`.
+    Typical usage is to set `decollate=False` in the engine and execute some postprocessing logic first,
+    then decollate the batch; otherwise, the engine will decollate the batch before the postprocessing.
+
+    Args:
+        event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED".
+            default to "MODEL_COMPLETED".
+        detach: whether to detach the tensors. scalar tensors will be detached into number types
+            instead of torch tensors.
+        decollate_batch: whether to decollate `engine.state.batch` of ignite engine.
+        batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate
+            in `engine.state.batch`, note that it will delete other keys not specified. if None,
+            will decollate all the keys. it replicates the scalar values to every item of the decollated list.
+        decollate_output: whether to decollate `engine.state.output` of ignite engine.
+        output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate
+            in `engine.state.output`, note that it will delete other keys not specified. if None,
+            will decollate all the keys. it replicates the scalar values to every item of the decollated list.
+        allow_missing_keys: don't raise exception if key is missing.
+
+    """
+
+    def __init__(
+        self,
+        event: str = "MODEL_COMPLETED",
+        detach: bool = True,
+        decollate_batch: bool = True,
+        batch_keys: Optional[KeysCollection] = None,
+        decollate_output: bool = True,
+        output_keys: Optional[KeysCollection] = None,
+        allow_missing_keys: bool = False,
+    ):
+        event = event.upper()
+        if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"):
+            raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.")
+        self.event = event
+
+        self.batch_transform = (
+            Decollated(keys=batch_keys, detach=detach, allow_missing_keys=allow_missing_keys)
+            if decollate_batch
+            else None
+        )
+
+        self.output_transform = (
+            Decollated(keys=output_keys, detach=detach, allow_missing_keys=allow_missing_keys)
+            if decollate_output
+            else None
+        )
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        if self.event == "MODEL_COMPLETED":
+            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)
+        else:
+            engine.add_event_handler(Events.ITERATION_COMPLETED, self)
+
+    def __call__(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        if self.batch_transform is not None:
+            engine.state.batch = self.batch_transform(engine.state.batch)
+        if self.output_transform is not None:
+            engine.state.output = self.output_transform(engine.state.output)
diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py
index acfd2eb94e..97b080b244 100644
--- a/monai/handlers/metrics_saver.py
+++ b/monai/handlers/metrics_saver.py
@@ -41,7 +41,7 @@ class MetricsSaver:
             typically, it's some intermediate values in metric computation.
             for example: mean dice of every channel of every image in the validation dataset.
             it must contain at least 2 dims: (batch, classes, ...),
-            if not, will unsequeeze to 2 dims.
+            if not, will unsqueeze to 2 dims.
             this arg can be: None, "*" or list of strings.
             None - don't save any metric_details into files.
             "*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py
index cb5342456c..05c6bd414d 100644
--- a/monai/handlers/postprocessing.py
+++ b/monai/handlers/postprocessing.py
@@ -26,24 +26,35 @@ class PostProcessing:
     """
     Ignite handler to execute additional post processing after the post processing in engines.
    So users can insert other handlers between engine postprocessing and this post processing handler.
+    If using components from `monai.transforms` as the `transform`, it is recommended to decollate `engine.state.batch`
+    and `engine.state.output` in the engine (set `decollate=True`) or in the `DecollateBatch` handler first.
 
     """
 
-    def __init__(self, transform: Callable) -> None:
+    def __init__(self, transform: Callable, event: str = "MODEL_COMPLETED") -> None:
         """
         Args:
             transform: callable function to execute on the `engine.state.batch` and `engine.state.output`.
                 can also be composed transforms.
+            event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED".
+                default to "MODEL_COMPLETED".
 
         """
         self.transform = transform
+        event = event.upper()
+        if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"):
+            raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.")
+        self.event = event
 
     def attach(self, engine: Engine) -> None:
         """
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
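Taken together, the new `DecollateBatch` handler and the `event` option on `PostProcessing` let an engine keep the batch collated through the forward pass and defer decollation to the handler chain. A rough, runnable sketch modeled on the handler tests in this series (the toy network and data are placeholders, not part of the patch)::

    import torch

    from monai.engines import SupervisedEvaluator
    from monai.handlers import DecollateBatch, PostProcessing
    from monai.transforms import Activationsd, Compose

    engine = SupervisedEvaluator(
        device=torch.device("cpu:0"),
        val_data_loader=[{"image": torch.tensor([[[[2.0], [3.0]]]])}],  # placeholder data
        epoch_length=1,
        network=torch.nn.PReLU(),  # placeholder network
        decollate=False,  # keep the batch collated until DecollateBatch runs
        val_handlers=[
            DecollateBatch(event="MODEL_COMPLETED"),
            PostProcessing(
                transform=Compose([Activationsd(keys="pred", sigmoid=True)]),
                event="ITERATION_COMPLETED",
            ),
        ],
    )
    engine.run()
    print(engine.state.output[0]["pred"])  # per-sample, sigmoid-activated prediction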
""" - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + if self.event == "MODEL_COMPLETED": + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) def __call__(self, engine: Engine) -> None: """ diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index 99ed91c714..d5756074fc 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -167,10 +167,13 @@ def _default_epoch_print(self, engine: Engine) -> None: out_str += self.key_var_format.format(name, value) self.logger.info(out_str) - if hasattr(engine.state, "key_metric_name"): - if hasattr(engine.state, "best_metric") and hasattr(engine.state, "best_metric_epoch"): - out_str = f"Key metric: {engine.state.key_metric_name} " - out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}" + if ( + hasattr(engine.state, "key_metric_name") + and hasattr(engine.state, "best_metric") + and hasattr(engine.state, "best_metric_epoch") + ): + out_str = f"Key metric: {engine.state.key_metric_name} " + out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}" self.logger.info(out_str) def _default_iteration_print(self, engine: Engine) -> None: @@ -201,18 +204,17 @@ def _default_iteration_print(self, engine: Engine) -> None: ) continue # not printing multi dimensional output out_str += self.key_var_format.format(name, value.item() if isinstance(value, torch.Tensor) else value) + elif is_scalar(loss): # not printing multi dimensional output + out_str += self.key_var_format.format( + self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss + ) else: - if is_scalar(loss): # not printing multi dimensional output - out_str += self.key_var_format.format( - self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss - ) - else: - warnings.warn( - "ignoring non-scalar output in StatsHandler," - " make sure `output_transform(engine.state.output)` returns" - " a scalar or a dictionary of key and scalar pairs to avoid this warning." - " {}".format(type(loss)) - ) + warnings.warn( + "ignoring non-scalar output in StatsHandler," + " make sure `output_transform(engine.state.output)` returns" + " a scalar or a dictionary of key and scalar pairs to avoid this warning." + " {}".format(type(loss)) + ) if not out_str: return # no value to print diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 04b2878701..4cf234241d 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -124,7 +124,7 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. 
""" if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): - warnings.warn("inverter requires `engine.state.batch` and `engine.state.outout` to be lists.") + warnings.warn("inverter requires `engine.state.batch` and `engine.state.output` to be lists.") else: for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): # combine `batch` and `output` to temporarily act as 1 dict for postprocessing diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 94c2f3585c..fde7bc4045 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -111,17 +111,16 @@ def string_list_all_gather(strings: List[str]) -> List[str]: max_len = max(all_lens) # pad the item to make sure the same length if length < max_len: - strings = strings + ["" for _ in range(max_len - length)] - - if get_torch_version_tuple() > (1, 6, 0): - for s in strings: - gathered = idist.all_gather(s) - for i, g in enumerate(gathered): - if len(g) > 0: - result[i].append(g) - else: + strings += ["" for _ in range(max_len - length)] + + if get_torch_version_tuple() <= (1, 6, 0): raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") + for s in strings: + gathered = idist.all_gather(s) + for i, g in enumerate(gathered): + if len(g) > 0: + result[i].append(g) return [i for k in result for i in k] @@ -214,12 +213,12 @@ class mean median max 5percentile 95percentile notnans ops = tuple(supported_ops.keys()) def _compute_op(op: str, d: np.ndarray): - if op.endswith("percentile"): - threshold = int(op.split("percentile")[0]) - return supported_ops["90percentile"]((d, threshold)) - else: + if not op.endswith("percentile"): return supported_ops[op](d) + threshold = int(op.split("percentile")[0]) + return supported_ops["90percentile"]((d, threshold)) + with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: f.write(f"class{deli}{deli.join(ops)}\n") for i, c in enumerate(np.transpose(v)): @@ -249,7 +248,7 @@ def from_engine(keys: KeysCollection, first: bool = False): Args: keys: specified keys to extract data from dictionary or decollated list of dictionaries. - first: whether only extract sepcified keys from the first item if input data is a list of dictionaries, + first: whether only extract specified keys from the first item if input data is a list of dictionaries, it's used to extract the scalar data which doesn't have batch dim and was replicated into every dictionary when decollating, like `loss`, etc. diff --git a/monai/losses/deform.py b/monai/losses/deform.py index acba229121..d96fa1440a 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -96,9 +96,7 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: energy = torch.mean(energy) # the batch and channel average elif self.reduction == LossReduction.SUM.value: energy = torch.sum(energy) # sum over the batch and channel dims - elif self.reduction == LossReduction.NONE.value: - pass # returns [N, n_classes] losses - else: + elif self.reduction != LossReduction.NONE.value: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 385a0c6426..7d56d8f436 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -108,7 +108,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. 
Raises: - AssertionError: When input and target (after one hot transform if setted) + AssertionError: When input and target (after one hot transform if set) have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. @@ -169,9 +169,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction == LossReduction.NONE.value: - pass # returns [N, n_classes] losses - else: + elif self.reduction != LossReduction.NONE.value: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -346,9 +344,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: f = torch.mean(f) # the batch and channel average elif self.reduction == LossReduction.SUM.value: f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction == LossReduction.NONE.value: - pass # returns [N, n_classes] losses - else: + elif self.reduction != LossReduction.NONE.value: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return f @@ -462,7 +458,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Compute the values of alpha to use alpha = self._compute_alpha_generalized_true_positives(flat_target) - # Compute the nemerator and denominator of the generalized Wasserstein Dice loss + # Compute the numerator and denominator of the generalized Wasserstein Dice loss if self.alpha_mode == "GDL": # use GDL-style alpha weights (i.e. normalize by the volume of each class) # contrary to the original definition we also use alpha in the "generalized all error". @@ -483,9 +479,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average elif self.reduction == LossReduction.SUM.value: wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims - elif self.reduction == LossReduction.NONE.value: - pass # returns [N, n_classes] losses - else: + elif self.reduction != LossReduction.NONE.value: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return wass_dice_loss @@ -543,12 +537,10 @@ def _compute_generalized_true_positive( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - # Compute the generalized true positive as in eq. 9 - generalized_true_pos = torch.sum( + return torch.sum( alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2], ) - return generalized_true_pos def _compute_denominator( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor @@ -565,12 +557,10 @@ def _compute_denominator( flat_target_extended = torch.unsqueeze(flat_target, dim=1) alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - # Compute the generalized true positive as in eq. 
9 - generalized_true_pos = torch.sum( + return torch.sum( alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2], ) - return generalized_true_pos def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor: """ diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index af23e03440..6f9326420b 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -21,8 +21,12 @@ def make_gaussian_kernel(sigma: int) -> torch.Tensor: if sigma <= 0: raise ValueError(f"expecting positive sigma, got sigma={sigma}") - kernel = gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx="sampled", normalize=False) - return kernel + return gaussian_1d( + sigma=torch.tensor(sigma), + truncated=3, + approx="sampled", + normalize=False, + ) def make_cauchy_kernel(sigma: int) -> torch.Tensor: @@ -92,9 +96,7 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: loss = torch.mean(loss) # the batch and channel average elif self.reduction == LossReduction.SUM.value: loss = torch.sum(loss) # sum over the batch and channel dims - elif self.reduction == LossReduction.NONE.value: - pass # returns [N, n_classes] losses - else: + elif self.reduction != LossReduction.NONE.value: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return loss diff --git a/monai/losses/spatial_mask.py b/monai/losses/spatial_mask.py index 4bad5bf73b..387300e507 100644 --- a/monai/losses/spatial_mask.py +++ b/monai/losses/spatial_mask.py @@ -36,11 +36,7 @@ def __init__(self, loss: Union[Callable, _Loss], *loss_args, **loss_kwargs) -> N loss_kwargs: keyword arguments to the loss function's constructor if `loss` is a class. """ super().__init__() - if inspect.isclass(loss): # call the loss function's constructor if it's a class - self.loss = loss(*loss_args, **loss_kwargs) - else: - self.loss = loss # loss is a callable loss function instance. - + self.loss = loss(*loss_args, **loss_kwargs) if inspect.isclass(loss) else loss if not callable(self.loss): raise ValueError("The loss function is not callable.") @@ -57,7 +53,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torc if input.dim() != mask.dim(): warnings.warn(f"Dim of input ({input.shape}) is different from mask ({mask.shape}).") - if not (input.shape[0] == mask.shape[0] or mask.shape[0] == 1): + if input.shape[0] != mask.shape[0] and mask.shape[0] != 1: raise ValueError(f"Batch size of mask ({mask.shape}) must be one or equal to input ({input.shape}).") if target.dim() > 1: if mask.shape[1] != 1: diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 8ebf547a1e..1bfd85a83e 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -142,5 +142,8 @@ def compute_meandice( y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o - f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) - return f # returns array of Dice with shape: [batch, n_classes] + return torch.where( + y_o > 0, + (2.0 * intersection) / denominator, + torch.tensor(float("nan"), device=y_o.device), + ) diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index ffc848b36d..bb4aa7c343 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -38,7 +38,7 @@ class IterationMetric(Metric): """ Base class of Metrics interface for computation on a batch of tensors, usually the data of 1 iteration. 
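The `compute_meandice` refactor above keeps the `[batch, n_classes]` output and the `nan` marker for classes that are absent from the ground truth. A quick sanity check, assuming one-hot formatted inputs::

    import torch

    from monai.metrics import compute_meandice

    # batch of 1, two one-hot classes on a 2x2 grid; class 1 is absent from the ground truth
    y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]])
    y = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]]])
    print(compute_meandice(y_pred, y))  # tensor([[1., nan]]) -- nan marks the empty class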
`__call__` is supposed to compute independent logic for several samples of `y_pred` and `y`(optional). - Ususally, subclass only needs to implement the `_compute_tensor` function for computation process. + Usually, subclass only needs to implement the `_compute_tensor` function for computation process. The input data shape should be `list of channel-first tensors` or a `batch-first tensor`. """ @@ -85,7 +85,7 @@ def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # concat the list of results if isinstance(ret[0], torch.Tensor): ret = torch.cat(ret, dim=0) - elif isinstance(ret[0], (list, tuple)) and all([isinstance(i, torch.Tensor) for i in ret[0]]): + elif isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]): # if _compute_tensor() returned not only 1 Tensor, concat them separately ret = [torch.cat([k[i] for k in ret], dim=0) for i in range(len(ret[0]))] @@ -119,7 +119,7 @@ class Cumulative(ABC): cum.add(x, y) cum.add(a, b) cum.add(c, d) - cum.agrregate() + cum.aggregate() result = cum.get_buffer() cum.reset() diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 044f99f1a5..a2a2f0853d 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -107,8 +107,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred = y_pred.float() y = y.float() - mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) - return mse_out + return compute_mean_error_metrics(y_pred, y, func=self.sq_func) class MAEMetric(RegressionMetric): @@ -142,8 +141,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred = y_pred.float() y = y.float() - mae_out = compute_mean_error_metrics(y_pred, y, func=self.abs_func) - return mae_out + return compute_mean_error_metrics(y_pred, y, func=self.abs_func) class RMSEMetric(RegressionMetric): @@ -179,8 +177,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y = y.float() mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) - rmse_out = torch.sqrt(mse_out) - return rmse_out + return torch.sqrt(mse_out) class PSNRMetric(RegressionMetric): @@ -223,14 +220,11 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any: y = y.float() mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func) - psnr_val = 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out) - return psnr_val + return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out) def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func) -> torch.Tensor: # reducing in only channel + spatial dimensions (not batch) - # reducion of batch handled inside __call__() using do_metric_reduction() in respective calling class + # reduction of batch handled inside __call__() using do_metric_reduction() in respective calling class flt = partial(torch.flatten, start_dim=1) - error_metric = torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True) - - return error_metric + return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index be7d130dce..a6fcc50be1 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -201,7 +201,7 @@ def get_surface_distance( return np.asarray(dis[seg_gt]) if distance_metric == "euclidean": dis = distance_transform_edt(~seg_gt) - elif distance_metric in ["chessboard", "taxicab"]: + elif distance_metric in {"chessboard", "taxicab"}: dis = 
distance_transform_cdt(~seg_gt, metric=distance_metric) else: raise ValueError(f"distance_metric {distance_metric} is not implemented.") diff --git a/monai/networks/blocks/crf.py b/monai/networks/blocks/crf.py index 219f1a70e2..49ff5bcd04 100644 --- a/monai/networks/blocks/crf.py +++ b/monai/networks/blocks/crf.py @@ -54,7 +54,8 @@ def __init__( bilateral_color_sigma: standard deviation in color space for the bilateral term. gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term. update_factor: determines the magnitude of each update. - compatibility_matrix: a matrix describing class compatibility, should be NxN where N is the numer of classes. + compatibility_matrix: a matrix describing class compatibility, + should be NxN where N is the number of classes. """ super(CRF, self).__init__() self.iterations = iterations @@ -92,11 +93,11 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor): for _ in range(self.iterations): # message passing step for both kernels - bliateral_output = PHLFilter.apply(output_tensor, bilateral_features) + bilateral_output = PHLFilter.apply(output_tensor, bilateral_features) gaussian_output = PHLFilter.apply(output_tensor, gaussian_features) # combining filter outputs - combined_output = self.bilateral_weight * bliateral_output + self.gaussian_weight * gaussian_output + combined_output = self.bilateral_weight * bilateral_output + self.gaussian_weight * gaussian_output # optionally running a compatibility transform if self.compatibility_matrix is not None: diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 975c2e15bb..9bee4c596e 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -58,5 +58,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]). 
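`MaxAvgPool.forward`, tidied in the next hunk, is a single concatenation of two pooling results along the channel axis. A standalone equivalent in plain PyTorch, with illustrative shapes::

    import torch
    from torch import nn

    x = torch.rand(2, 3, 8, 8)
    max_pool = nn.MaxPool2d(kernel_size=2)
    avg_pool = nn.AvgPool2d(kernel_size=2)
    out = torch.cat([max_pool(x), avg_pool(x)], dim=1)
    print(out.shape)  # torch.Size([2, 6, 4, 4]) -- channels doubled, spatial halved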
""" - x_d = torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) - return x_d + return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) diff --git a/monai/networks/blocks/dynunet_block.py b/monai/networks/blocks/dynunet_block.py index ed4f3a2cbb..bb654d841c 100644 --- a/monai/networks/blocks/dynunet_block.py +++ b/monai/networks/blocks/dynunet_block.py @@ -213,8 +213,7 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int): ) def forward(self, inp): - out = self.conv(inp) - return out + return self.conv(inp) def get_conv_layer( diff --git a/monai/networks/blocks/dynunet_block_v1.py b/monai/networks/blocks/dynunet_block_v1.py index d24f7dfab0..d5d9bbf3dc 100644 --- a/monai/networks/blocks/dynunet_block_v1.py +++ b/monai/networks/blocks/dynunet_block_v1.py @@ -143,10 +143,8 @@ def __init__( def _get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_groups: int = 16): if norm_name not in ["batch", "instance", "group"]: raise ValueError(f"Unsupported normalization mode: {norm_name}") - if norm_name == "group": - if out_channels % num_groups != 0: - raise AssertionError("out_channels should be divisible by num_groups.") - norm = Norm[norm_name, spatial_dims](num_groups=num_groups, num_channels=out_channels, affine=True) - else: - norm = Norm[norm_name, spatial_dims](out_channels, affine=True) - return norm + if norm_name != "group": + return Norm[norm_name, spatial_dims](out_channels, affine=True) + if out_channels % num_groups != 0: + raise AssertionError("out_channels should be divisible by num_groups.") + return Norm[norm_name, spatial_dims](num_groups=num_groups, num_channels=out_channels, affine=True) diff --git a/monai/networks/blocks/fcn.py b/monai/networks/blocks/fcn.py index c7cd7cca30..aa6d69fad0 100644 --- a/monai/networks/blocks/fcn.py +++ b/monai/networks/blocks/fcn.py @@ -91,8 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.relu(x) x = self.conv2(x) - out = residual + x - return out + return residual + x class FCN(nn.Module): @@ -191,7 +190,7 @@ def forward(self, x: torch.Tensor): fs2 = self.refine7(self.up_conv(fs1) + gcfm3) fs3 = self.refine8(self.up_conv(fs2) + gcfm4) fs4 = self.refine9(self.up_conv(fs3) + gcfm5) - out = self.refine10(self.up_conv(fs4)) + return self.refine10(self.up_conv(fs4)) else: fs1 = self.refine6( F.interpolate(gcfm1, fm3.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm2 @@ -203,8 +202,14 @@ def forward(self, x: torch.Tensor): fs4 = self.refine9( F.interpolate(fs3, conv_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm5 ) - out = self.refine10(F.interpolate(fs4, org_input.size()[2:], mode=self.upsample_mode, align_corners=True)) - return out + return self.refine10( + F.interpolate( + fs4, + org_input.size()[2:], + mode=self.upsample_mode, + align_corners=True, + ) + ) class MCFCN(FCN): @@ -253,5 +258,4 @@ def forward(self, x: torch.Tensor): x: in shape (batch, in_channels, spatial_1, spatial_2). """ x = self.init_proj(x) - out = super(MCFCN, self).forward(x) - return out + return super(MCFCN, self).forward(x) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index f7e9a0d243..647f2648c8 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -154,7 +154,7 @@ class ResNet(nn.Module): Args: block: which ResNet block to use, either Basic or Bottleneck. layers: how many layers to use. - block_inplanes: determine the size of planes at each step. Also tuneable with widen_factor. 
+ block_inplanes: determine the size of planes at each step. Also tunable with widen_factor. spatial_dims: number of spatial dimensions of the input image. n_input_channels: number of input channels for first convolutional layer. conv1_t_size: size of first convolution layer, determines kernel and padding. diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 7fa8d6600b..b380f7d42a 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -161,7 +161,7 @@ def __call__(self, input_): def inverse(self, data): invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] - if len(invertible_transforms) == 0: + if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") # loop backwards over transforms diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index f636bdae4e..9f9f880565 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -137,7 +137,7 @@ class SpatialPad(TorchOrNumpyTransform): (no padding). for example: if the spatial size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} - Pad image symmetric on every side or only pad at the end sides. Defaults to ``"symmetric"``. + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. @@ -1043,6 +1043,10 @@ class ResizeWithPadOrCrop(Transform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + method: {``"symmetric"``, ``"end"``} + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. 
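Both `SpatialPad` and the updated `ResizeWithPadOrCrop` accept the `method` option documented above; with `"end"`, all padding is appended after the existing content rather than split across both sides. A small illustration, with assumed shapes::

    import numpy as np

    from monai.transforms import SpatialPad

    img = np.zeros((1, 4, 5))  # channel-first, spatial shape (4, 5)
    padded = SpatialPad(spatial_size=[4, 8], method="end", mode="constant")(img)
    print(padded.shape)  # (1, 4, 8) -- the extra columns are appended at the end only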
+ more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ @@ -1050,8 +1054,10 @@ def __init__( self, spatial_size: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + method: Union[Method, str] = Method.SYMMETRIC, + **np_kwargs, ): - self.padder = SpatialPad(spatial_size=spatial_size, mode=mode) + self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 336f48593b..08820149e9 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -86,7 +86,7 @@ def __call__(self, batch: Any): break max_shapes.append(elem[key_or_idx].shape[1:]) # len > 0 if objects were arrays, else skip as no padding to be done - if len(max_shapes) == 0: + if not max_shapes: continue max_shape = np.array(max_shapes).max(axis=0) # If all same size, skip @@ -121,9 +121,9 @@ def inverse(data: dict) -> DataObjects.Dict: raise RuntimeError("Inverse can only currently be applied on dictionaries.") d = deepcopy(data) - for key in d.keys(): + for key in d: transform_key = str(key) + InverseKeys.KEY_SUFFIX - if transform_key in d.keys(): + if transform_key in d: transform = d[transform_key][-1] if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__: d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 70bb6beab1..89be3cf26e 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -125,7 +125,7 @@ def __init__( for example: if the spatial size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} - Pad image symmetric on every side or only pad at the end sides. Defaults to ``"symmetric"``. + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. @@ -1341,6 +1341,10 @@ class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + method: {``"symmetric"``, ``"end"``} + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. 
+ more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ @@ -1350,10 +1354,12 @@ def __init__( spatial_size: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, + method: Union[Method, str] = Method.SYMMETRIC, + **np_kwargs, ) -> None: super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size) + self.padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **np_kwargs) def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ba5d18dfb9..d7245760a2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -481,7 +481,7 @@ def _generate_random_field( coeff_mat[np_pts[:, 0], np_pts[:, 1], np_pts[:, 2]] = coeff field = np.polynomial.legendre.leggrid3d(coords[0], coords[1], coords[2], coeff_mat) else: - raise NotImplementedError("only supoprts 2D or 3D fields") + raise NotImplementedError("only supports 2D or 3D fields") return field def randomize(self, data: np.ndarray) -> None: @@ -1515,7 +1515,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform): channels at once, or channel-wise if ``channel_wise = True``. intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b) - uniformly for all channels. Or pass sequence of intevals + uniformly for all channels. Or pass sequence of intervals ((a0, b0), (a1, b1), ...) to sample for each respective channel. In the second case, the number of 2-tuples must match the number of channels. @@ -1576,7 +1576,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, intensity_range = self._make_sequence(x) self._randomize(x, intensity_range) - # build/appy transform only if there are spike locations + # build/apply transform only if there are spike locations if self.sampled_locs: transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) return transform(x) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index fc5ec479ad..22525b41b5 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1220,7 +1220,7 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): img_intensity_range: Intensity range to sample for ``"image"`` key. Pass a tuple `(a, b)` to sample the log-intensity from the interval `(a, b)` uniformly for all - channels. Or pass sequence of intevals `((a0, b0), (a1, b1), ...)` + channels. Or pass sequence of intervals `((a0, b0), (a1, b1), ...)` to sample for each respective channel. In the second case, the number of 2-tuples must match the number of channels. Default ranges is `(0.95x, 1.10x)` where `x` is the mean diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index c6dad2fcd0..d9c6790840 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -10,13 +10,14 @@ # limitations under the License. 
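With the new `method` and `np_kwargs` pass-through shown above, `ResizeWithPadOrCrop` can pad and crop different axes in one call while forwarding extra arguments to `np.pad`. A sketch mirroring the updated unit tests later in this series::

    import numpy as np

    from monai.transforms import ResizeWithPadOrCrop

    padcrop = ResizeWithPadOrCrop(spatial_size=[15, 4, 8], mode="constant", method="end", constant_values=1)
    out = padcrop(np.zeros((3, 8, 8, 4)))
    print(out.shape)  # (3, 15, 4, 8): axis 0 padded at the end with 1s, axis 1 center-cropped, axis 2 padded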
 import warnings
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import DataLoader as TorchDataLoader
 
+from monai.config import KeysCollection
 from monai.data.dataloader import DataLoader
-from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate
+from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch
 from monai.transforms.croppad.batch import PadListDataCollate
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import MapTransform, Transform
@@ -103,18 +104,39 @@ def __call__(self, data: Dict[str, Any]) -> Any:
 
 class Decollated(MapTransform):
     """
-    Decollate a batch of data.
-    Note that unlike most MapTransforms, this will decollate all data, so keys are not needed.
+    Decollate a batch of data; if the input is a dictionary, it can also decollate only the specified keys.
+    Note that unlike most MapTransforms, it will delete any keys not specified, and if `keys=None`, it will
+    decollate all the data in the input.
+    It also replicates the scalar values to every item of the decollated list.
 
     Args:
+        keys: keys of the corresponding items to decollate; note that keys not specified will be deleted.
+            if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`.
         detach: whether to detach the tensors. Scalar tensors will be detached into number types
            instead of torch tensors.
+        allow_missing_keys: don't raise exception if key is missing.
 
    """
 
-    def __init__(self, keys="", detach: bool = True) -> None:
-        super().__init__(keys=keys)
+    def __init__(
+        self,
+        keys: Optional[KeysCollection] = None,
+        detach: bool = True,
+        allow_missing_keys: bool = False,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
         self.detach = detach
 
-    def __call__(self, data: dict):
-        return decollate_batch(data, detach=self.detach)
+    def __call__(self, data: Union[Dict, List]):
+        d: Union[Dict, List]
+        if len(self.keys) == 1 and self.keys[0] is None:
+            # decollate all data, as `key_iterator` doesn't support `None` as a key
+            d = data
+        else:
+            if not isinstance(data, dict):
+                raise TypeError("input data is not a dictionary, but keys to decollate were specified.")
+            d = {}
+            for key in self.key_iterator(data):
+                d[key] = data[key]
+
+        return decollate_batch(rep_scalar_to_batch(d), detach=self.detach)
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index d1cdafc533..7da2b4abee 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -575,7 +575,7 @@ def __init__(
             the meta data is a dictionary object which contains: filename, original_shape, etc.
             it can be a sequence of string, map to the `keys`.
             if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
-            will extract the filename of input image to save classifcation results.
+            will extract the filename of input image to save classification results.
         meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.
             so need the key to extract the metadata of input image, like filename, etc. default is `meta_dict`.
             for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
@@ -586,7 +586,7 @@ def __init__(
             output_dir: if `saver=None`, specify the directory to save the CSV file.
             filename: if `saver=None`, specify the name of the saved CSV file.
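The reworked `Decollated` above behaves like the transform tests later in this series: with `keys=None` everything is decollated, and scalar values are first replicated across the batch via `rep_scalar_to_batch`. A minimal sketch::

    import torch

    from monai.transforms import Decollated

    batch = {
        "image": torch.rand(2, 1, 4, 4),  # batch of 2
        "meta": {"id": ["a", "b"]},
        "loss": 0.85,  # scalar, replicated to every decollated item
    }
    items = Decollated(keys=None, detach=True)(batch)
    print(len(items), items[0]["meta"]["id"], items[1]["loss"])  # 2 a 0.85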
overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content, if True, - will clear the file before saving. otherwise, will apend new content to the CSV file. + will clear the file before saving. otherwise, will append new content to the CSV file. flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately in this transform and clear the cache. default to True. If False, may need user to call `saver.finalize()` manually or use `ClassificationSaver` handler. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ee6a77d699..1cf3474837 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -106,7 +106,7 @@ def __init__( of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, if shorter, will pad with `1.0`. if the components of the `pixdim` are non-positive values, the transform will use the - corresponding components of the origial pixdim, which is computed from the `affine` + corresponding components of the original pixdim, which is computed from the `affine` matrix of input image. diagonal: whether to resample the input to have a diagonal affine matrix. If True, the input data is resampled to the following affine:: @@ -166,7 +166,7 @@ def __call__( the output data type is always ``np.float32``. output_spatial_shape: specify the shape of the output data_array. This is typically useful for the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization - error with the affines. + error with the affine. Raises: ValueError: When ``data_array`` has no spatial dimensions. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e341c93bf7..46bcae6891 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -151,7 +151,7 @@ def __init__( of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, if shorter, will pad with `1.0`. if the components of the `pixdim` are non-positive values, the transform will use the - corresponding components of the origial pixdim, which is computed from the `affine` + corresponding components of the original pixdim, which is computed from the `affine` matrix of input image. diagonal: whether to resample the input to have a diagonal affine matrix. 
If True, the input data is resampled to the following affine:: diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 631cac0fbd..1a62b31a37 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -462,14 +462,13 @@ def key_iterator( extra_iterables: anything else to be iterated through """ # if no extra iterables given, create a dummy list of Nones - ex_iters = extra_iterables if extra_iterables else [[None] * len(self.keys)] + ex_iters = extra_iterables or [[None] * len(self.keys)] # loop over keys and any extra iterables _ex_iters: List[Any] for key, *_ex_iters in zip(self.keys, *ex_iters): # all normal, yield (what we yield depends on whether extra iterables were given) - if key in data.keys(): + if key in data: yield (key,) + tuple(_ex_iters) if extra_iterables else key - # if missing keys not allowed, raise elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 78f0f142ce..dee3647258 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -368,7 +368,7 @@ def __call__(self, data): Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and - objects keep the original. for dictionay, list or tuple, ensure every item as expected type + objects keep the original. for dictionary, list or tuple, ensure every item as expected type if applicable. """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 28b0d365a8..9d38d37ca9 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -853,7 +853,8 @@ def map_spatial_axes( """ if spatial_axes is None: - spatial_axes_ = list(range(1, img_ndim) if channel_first else range(0, img_ndim - 1)) + spatial_axes_ = list(range(1, img_ndim) if channel_first else range(img_ndim - 1)) + else: spatial_axes_ = [] for a in ensure_tuple(spatial_axes): @@ -974,7 +975,7 @@ def convert_to_tensor(data): Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. - for dictionay, list or tuple, convert every item to a Tensor if applicable. + for dictionary, list or tuple, convert every item to a Tensor if applicable. """ if isinstance(data, torch.Tensor): @@ -993,7 +994,7 @@ def convert_to_tensor(data): elif isinstance(data, list): return [convert_to_tensor(i) for i in data] elif isinstance(data, tuple): - return tuple([convert_to_tensor(i) for i in data]) + return tuple(convert_to_tensor(i) for i in data) return data @@ -1006,7 +1007,7 @@ def convert_to_numpy(data): Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. - for dictionay, list or tuple, convert every item to a numpy array if applicable. + for dictionary, list or tuple, convert every item to a numpy array if applicable. """ if isinstance(data, torch.Tensor): @@ -1035,7 +1036,7 @@ def tensor_to_numpy(data): Args: data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. - will convert the Tensor data to numpy array, others keep the original. 
for dictionay, list or tuple, + will convert the Tensor data to numpy array, others keep the original. for dictionary, list or tuple, convert every Tensor item to numpy array if applicable. """ @@ -1048,6 +1049,6 @@ def tensor_to_numpy(data): elif isinstance(data, list): return [tensor_to_numpy(i) for i in data] elif isinstance(data, tuple): - return tuple([tensor_to_numpy(i) for i in data]) + return tuple(tensor_to_numpy(i) for i in data) return data diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 255154a84b..5cb365e088 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -85,7 +85,7 @@ def _torch_all_gather(data: torch.Tensor) -> List[torch.Tensor]: # all gather across all processes output = [torch.zeros_like(data) for _ in range(dist.get_world_size())] dist.all_gather(output, data) - # remove the padding items, if all the input data doesn't have batch dim, suqeeze the first dim + # remove the padding items, if all the input data doesn't have batch dim, squeeze the first dim return [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)] def _ignite_all_gather(data: torch.Tensor) -> List[torch.Tensor]: diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 78136f6404..b86f9f442c 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -200,12 +200,11 @@ def plot_engine_status( image = image_fn(k, v) if image is not None: imagemap[f"{k}_{i}"] = image - else: + elif isinstance(s, torch.Tensor): label = "Batch" if src is engine.state.batch else "Output" - if isinstance(s, torch.Tensor): - image = image_fn(label, s) - if image is not None: - imagemap[f"{label}_{i}"] = image + image = image_fn(label, s) + if image is not None: + imagemap[f"{label}_{i}"] = image axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 43a774952f..c6ca5fbe44 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -181,9 +181,7 @@ def fall_back_tuple( def is_scalar_tensor(val: Any) -> bool: - if isinstance(val, torch.Tensor) and val.ndim == 0: - return True - return False + return isinstance(val, torch.Tensor) and val.ndim == 0 def is_scalar(val: Any) -> bool: @@ -202,7 +200,7 @@ def progress_bar(index: int, count: int, desc: Optional[str] = None, bar_len: in bar_len: the total length of the bar on screen, default is 30 char. newline: whether to print in a new line for every index. 
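For context, `progress_bar` is a small self-contained console utility; with the default `newline=False` (simplified in the next hunk) it redraws the bar in place. A minimal usage sketch::

    from monai.utils.misc import progress_bar

    count = 100
    for i in range(1, count + 1):
        progress_bar(i, count, desc="processing", bar_len=30, newline=False)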
""" - end = "\r" if newline is False else "\r\n" + end = "\r" if not newline else "\r\n" filled_len = int(bar_len * index // count) bar = f"{desc} " if desc is not None else "" bar += "[" + "=" * filled_len + " " * (bar_len - filled_len) + "]" diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 65a6118670..94943a8c37 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -58,9 +58,8 @@ def __init__( if self.cache_dir is None: self.cache_dir = tempfile.gettempdir() - else: - if not os.path.isdir(self.cache_dir): - raise ValueError("Given `cache_dir` is not a valid directory.") + elif not os.path.isdir(self.cache_dir): + raise ValueError("Given `cache_dir` is not a valid directory.") self.cached: Dict[str, str] = {} diff --git a/tests/min_tests.py b/tests/min_tests.py index a3f140b856..1cd54f35d0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -134,6 +134,7 @@ def run_testsuit(): "test_unetr", "test_unetr_block", "test_vit", + "test_handler_decollate_batch", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 235862ff96..296eaa8f75 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -211,6 +211,42 @@ def test_dict_examples(self): out = decollate_batch(test_case, detach=False) self.assertEqual(out[0]["out"], "test") + def test_decollated(self): + test_case = { + "image": torch.tensor([[[1, 2]], [[3, 4]]]), + "meta": {"out": ["test", "test"]}, + "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, + "loss": 0.85, + } + transform = Decollated(keys=["meta", "image_meta_dict"], detach=False) + out = transform(test_case) + self.assertFalse("loss" in out) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor)) + # decollate all data with keys=None + transform = Decollated(keys=None, detach=True) + out = transform(test_case) + self.assertEqual(out[1]["loss"], 0.85) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], float)) + + # test list input + test_case = [ + torch.tensor([[[1, 2]], [[3, 4]]]), + {"out": ["test", "test"]}, + {"scl_slope": torch.Tensor((0.0, 0.0))}, + 0.85, + ] + transform = Decollated(keys=None, detach=False) + out = transform(test_case) + # the 4th item in the list is scalar loss value + self.assertEqual(out[1][3], 0.85) + self.assertEqual(out[0][1]["out"], "test") + self.assertEqual(out[0][2]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py new file mode 100644 index 0000000000..bc74cf5328 --- /dev/null +++ b/tests/test_handler_decollate_batch.py @@ -0,0 +1,63 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.engines import SupervisedEvaluator +from monai.handlers import DecollateBatch, PostProcessing +from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd + + +class TestHandlerDecollateBatch(unittest.TestCase): + def test_compute(self): + data = [ + {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, + {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]}, + ] + + handlers = [ + DecollateBatch(event="MODEL_COMPLETED"), + PostProcessing( + transform=Compose( + [ + Activationsd(keys="pred", sigmoid=True), + CopyItemsd(keys="filename", times=1, names="filename_bak"), + AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), + ] + ) + ), + ] + # set up engine, PostProcessing handler works together with postprocessing transforms of engine + engine = SupervisedEvaluator( + device=torch.device("cpu:0"), + val_data_loader=data, + epoch_length=2, + network=torch.nn.PReLU(), + # set decollate=False and execute some postprocessing first, then decollate in handlers + postprocessing=lambda x: dict(pred=x["pred"] + 1.0), + decollate=False, + val_handlers=handlers, + ) + engine.run() + + expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]) + + for o, e in zip(engine.state.output, expected): + torch.testing.assert_allclose(o["pred"], e) + filename = o.get("filename_bak") + if filename is not None: + self.assertEqual(filename, "test2") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 589adfb35d..552cde9eb1 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -28,7 +28,8 @@ CopyItemsd(keys="filename", times=1, names="filename_bak"), AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), ] - ) + ), + "event": "iteration_completed", }, True, torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]), diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 3e3aed709f..fbdb651297 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -191,12 +191,15 @@ def test_shape(self, transform, expected_shape, kwargs=None): "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), }, ] - dataset_postcached.set_data(data=test_data_new) # test new exchanged cache content if transform is None: + dataset_postcached.set_data(data=test_data_new) self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + else: + with self.assertRaises(RuntimeError): + dataset_postcached.set_data(data=test_data_new) # filename list updated, files do not exist @skip_if_windows diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 53fb0d3002..46f1fc86cc 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -23,7 +23,7 @@ (3, 15, 8, 8), ], [ - {"spatial_size": [15, 4, 8], "mode": "constant"}, + {"spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, (3, 8, 8, 4), (3, 15, 4, 8), ], diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 
8cbb31b5a6..32a62a9e16 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -23,7 +23,7 @@ (3, 15, 8, 8), ], [ - {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant"}, + {"keys": "img", "spatial_size": [15, 4, 8], "mode": "constant", "method": "end", "constant_values": 1}, {"img": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 8), ], From 06f7a6e3baa7a5ae10456d9111fab612802d3fce Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 11:31:32 +0100 Subject: [PATCH 099/176] fix torch seed Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a62b31a37..3ff12b7522 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -218,7 +218,7 @@ def set_random_state( _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed _seed = _seed % MAX_SEED self.R = np.random.RandomState(_seed) - self.R_torch = torch.manual_seed(int(_seed)) + self.R_torch = torch.Generator().manual_seed(int(_seed)) return self if state is not None: @@ -228,6 +228,7 @@ def set_random_state( return self self.R = np.random.RandomState() + self.R_torch = None return self def randomize(self, data: Any) -> None: From f7aeb070877d9d9d36573315de88620f40d64114 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 12:29:04 +0100 Subject: [PATCH 100/176] remove duplicate class Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/dictionary.py | 179 ------------------------- 1 file changed, 179 deletions(-) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index ea0793a9c4..89be3cf26e 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1326,185 +1326,6 @@ def inverse(self, data: DataObjects.Mapping) -> DataObjects.Dict: return d -class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): - """ - Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`. - Crop random fixed sized regions with the center being a class based on the specified ratios of every class. - The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the - cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`:: - - cropper = RandCropByLabelClassesd( - keys=["image", "label"], - label_key="label", - spatial_size=[3, 3], - ratios=[1, 2, 3, 1], - num_classes=4, - num_samples=2, - ) - data = { - "image": np.array([ - [[0.0, 0.3, 0.4, 0.2, 0.0], - [0.0, 0.1, 0.2, 0.1, 0.4], - [0.0, 0.3, 0.5, 0.2, 0.0], - [0.1, 0.2, 0.1, 0.1, 0.0], - [0.0, 0.1, 0.2, 0.1, 0.0]] - ]), - "label": np.array([ - [[0, 0, 0, 0, 0], - [0, 1, 2, 1, 0], - [0, 1, 3, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]] - ]), - } - result = cropper(data) - - The 2 randomly cropped samples of `label` can be: - [[0, 1, 2], [[0, 0, 0], - [0, 1, 3], [1, 2, 1], - [0, 0, 0]] [1, 3, 0]] - - If a dimension of the expected spatial size is bigger than the input image size, - will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped - results of several images may not have exactly same shape. 
- - Args: - keys: keys of the corresponding items to be transformed. - See also: :py:class:`monai.transforms.compose.MapTransform` - label_key: name of key for label image, this will be used for finding indices of every class. - spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. - if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. - if its components have non-positive values, the corresponding size of `label` will be used. - for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`, - the spatial size of output data will be [32, 40, 40]. - ratios: specified ratios of every class in the label to generate crop centers, including background class. - if None, every class will have the same ratio to generate crop centers. - num_classes: number of classes for argmax label, not necessary for One-Hot label. - num_samples: number of samples (crop regions) to take in each list. - image_key: if image_key is not None, only return the indices of every class that are within the valid - region of the image (``image > image_threshold``). - image_threshold: if enabled `image_key`, use ``image > image_threshold`` to - determine the valid image content area and select class indices only in this area. - indices_key: if provided pre-computed indices of every class, will ignore above `image` and - `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array - of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first - and cache the results for better performance. - meta_keys: explicitly indicate the key of the corresponding meta data dictionary. - used to add `patch_index` to the meta dict. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according - to the key data, default is `meta_dict`, the meta data is a dictionary object. - used to add `patch_index` to the meta dict. - allow_missing_keys: don't raise exception if key is missing. 
- - """ - - def __init__( - self, - keys: KeysCollection, - label_key: str, - spatial_size: Union[Sequence[int], int], - ratios: Optional[List[Union[float, int]]] = None, - num_classes: Optional[int] = None, - num_samples: int = 1, - image_key: Optional[str] = None, - image_threshold: float = 0.0, - indices_key: Optional[str] = None, - meta_keys: Optional[KeysCollection] = None, - meta_key_postfix: str = "meta_dict", - allow_missing_keys: bool = False, - ) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) - self.label_key = label_key - self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size - self.ratios = ratios - self.num_classes = num_classes - self.num_samples = num_samples - self.image_key = image_key - self.image_threshold = image_threshold - self.indices_key = indices_key - self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) - if len(self.keys) != len(self.meta_keys): - raise ValueError("meta_keys should have the same length as keys.") - self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self.centers: Optional[List[List[np.ndarray]]] = None - - def randomize( - self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, - ) -> None: - self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] - if indices is None: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: - indices_ = indices - self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R - ) - - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: - d = dict(data) - label = d[self.label_key] - image = d[self.image_key] if self.image_key else None - indices = d.get(self.indices_key) if self.indices_key is not None else None - - self.randomize(label, indices, image) - if not isinstance(self.spatial_size, tuple): - raise ValueError("spatial_size must be a valid tuple.") - if self.centers is None: - raise ValueError("no available ROI centers to crop.") - - # initialize returned list with shallow copy to preserve key ordering - results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] - - for i, center in enumerate(self.centers): - # fill in the extra keys with unmodified data - for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(data[key]) - for key in self.key_iterator(d): - img = d[key] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - orig_size = img.shape[1:] - results[i][key] = cropper(img) - self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) - # add `patch_index` to the meta data - for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): - meta_key = meta_key or f"{key}_{meta_key_postfix}" - if meta_key not in results[i]: - results[i][meta_key] = {} # type: ignore - results[i][meta_key][Key.PATCH_INDEX] = i - - return results - - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: - d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) - current_size = np.asarray(d[key].shape[1:]) - center 
= transform[InverseKeys.EXTRA_INFO]["center"] - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - # get required pad to start and end - pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) - pad_to_end = orig_size - current_size - pad_to_start - # interleave mins and maxes - pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) - inverse_transform = BorderPad(pad) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - - return d - - class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`. From 31885ea1de784c7c72bf81975413acdfce74da47 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 12:37:28 +0100 Subject: [PATCH 101/176] remove duplicate class Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 133 ------------------------------ 1 file changed, 133 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 2ab987bfef..9f9f880565 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1029,139 +1029,6 @@ def __call__( return results -class RandCropByLabelClasses(Randomizable, Transform): - """ - Crop random fixed sized regions with the center being a class based on the specified ratios of every class. - The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the - cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`:: - - image = np.array([ - [[0.0, 0.3, 0.4, 0.2, 0.0], - [0.0, 0.1, 0.2, 0.1, 0.4], - [0.0, 0.3, 0.5, 0.2, 0.0], - [0.1, 0.2, 0.1, 0.1, 0.0], - [0.0, 0.1, 0.2, 0.1, 0.0]] - ]) - label = np.array([ - [[0, 0, 0, 0, 0], - [0, 1, 2, 1, 0], - [0, 1, 3, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]] - ]) - cropper = RandCropByLabelClasses( - spatial_size=[3, 3], - ratios=[1, 2, 3, 1], - num_classes=4, - num_samples=2, - ) - label_samples = cropper(img=label, label=label, image=image) - - The 2 randomly cropped samples of `label` can be: - [[0, 1, 2], [[0, 0, 0], - [0, 1, 3], [1, 2, 1], - [0, 0, 0]] [1, 3, 0]] - - If a dimension of the expected spatial size is bigger than the input image size, - will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped - results of several images may not have exactly same shape. - - Args: - spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. - if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. - if its components have non-positive values, the corresponding size of `label` will be used. - for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`, - the spatial size of output data will be [32, 40, 40]. - ratios: specified ratios of every class in the label to generate crop centers, including background class. - if None, every class will have the same ratio to generate crop centers. - label: the label image that is used for finding every classes, if None, must set at `self.__call__`. - num_classes: number of classes for argmax label, not necessary for One-Hot label. - num_samples: number of samples (crop regions) to take in each list. 
- image: if image is not None, only return the indices of every class that are within the valid - region of the image (``image > image_threshold``). - image_threshold: if enabled `image`, use ``image > image_threshold`` to - determine the valid image content area and select class indices only in this area. - indices: if provided pre-computed indices of every class, will ignore above `image` and - `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array - of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first - and cache the results for better performance. - - """ - - def __init__( - self, - spatial_size: Union[Sequence[int], int], - ratios: Optional[List[Union[float, int]]] = None, - label: Optional[np.ndarray] = None, - num_classes: Optional[int] = None, - num_samples: int = 1, - image: Optional[np.ndarray] = None, - image_threshold: float = 0.0, - indices: Optional[List[np.ndarray]] = None, - ) -> None: - self.spatial_size = ensure_tuple(spatial_size) - self.ratios = ratios - self.label = label - self.num_classes = num_classes - self.num_samples = num_samples - self.image = image - self.image_threshold = image_threshold - self.centers: Optional[List[List[np.ndarray]]] = None - self.indices = indices - - def randomize( - self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, - ) -> None: - self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] - if indices is None: - if self.indices is not None: - indices_ = self.indices - else: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: - indices_ = indices - self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R - ) - - def __call__( - self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - indices: Optional[List[np.ndarray]] = None, - ) -> List[np.ndarray]: - """ - Args: - img: input data to crop samples from based on the ratios of every class, assumes `img` is a - channel-first array. - label: the label image that is used for finding indices of every class, if None, use `self.label`. - image: optional image data to help select valid area, can be same as `img` or another image array. - use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`. - indices: list of indices for every class in the image, used to randomly select crop centers. 
- - """ - if label is None: - label = self.label - if label is None: - raise ValueError("label should be provided.") - if image is None: - image = self.image - - self.randomize(label, indices, image) - results: List[np.ndarray] = [] - if self.centers is not None: - for center in self.centers: - cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) - - return results - - class ResizeWithPadOrCrop(Transform): """ Resize an image to a target spatial size by either centrally cropping the image or From 5071afeece621d703cbdab9cf9bf97da78eb30bb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 17:38:56 +0100 Subject: [PATCH 102/176] ClassesToIndices, RandCropByLabelClasses Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 41 +++-- monai/transforms/croppad/dictionary.py | 6 +- monai/transforms/utility/array.py | 13 +- monai/transforms/utils.py | 26 +-- tests/test_classes_to_indices.py | 107 +++++++----- tests/test_map_classes_to_indices.py | 189 ++++++++++++++-------- tests/test_rand_crop_by_label_classes.py | 109 +++++++------ tests/test_rand_crop_by_label_classesd.py | 89 +++++----- 8 files changed, 341 insertions(+), 239 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 9f9f880565..632e34c92f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -896,7 +896,7 @@ def __call__( return results -class RandCropByLabelClasses(Randomizable, Transform): +class RandCropByLabelClasses(Randomizable, TorchOrNumpyTransform): """ Crop random fixed sized regions with the center being a class based on the specified ratios of every class. The label data can be One-Hot format array or Argmax data. 
And will return a list of arrays for all the @@ -959,12 +959,12 @@ def __init__( self, spatial_size: Union[Sequence[int], int], ratios: Optional[List[Union[float, int]]] = None, - label: Optional[np.ndarray] = None, + label: Optional[DataObjects.Images] = None, num_classes: Optional[int] = None, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, - indices: Optional[List[np.ndarray]] = None, + indices: Optional[List[DataObjects.Images]] = None, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.ratios = ratios @@ -978,30 +978,29 @@ def __init__( def randomize( self, - label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + indices: Optional[List[DataObjects.Images]] = None, + image: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] - if indices is None: - if self.indices is not None: - indices_ = self.indices - else: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) - else: + indices_: List[DataObjects.Images] + if indices is not None: indices_ = indices + elif self.indices is not None: + indices_ = self.indices + else: + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) self.centers = generate_label_classes_crop_centers( self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R ) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - indices: Optional[List[np.ndarray]] = None, - ) -> List[np.ndarray]: + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, + indices: Optional[List[DataObjects.Images]] = None, + ) -> List[DataObjects.Images]: """ Args: img: input data to crop samples from based on the ratios of every class, assumes `img` is a @@ -1020,11 +1019,11 @@ def __call__( image = self.image self.randomize(label, indices, image) - results: List[np.ndarray] = [] + results: List[DataObjects.Images] = [] if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) # type: ignore + results.append(cropper(img)) return results diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 89be3cf26e..6ee79dc9cf 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1256,11 +1256,11 @@ def __init__( def randomize( self, label: np.ndarray, - indices: Optional[List[np.ndarray]] = None, + indices: Optional[List[DataObjects.Images]] = None, image: Optional[np.ndarray] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] + indices_: List[DataObjects.Images] if indices is None: indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) else: @@ -1269,7 +1269,7 @@ def randomize( self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R ) - def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, DataObjects.Images]]: d = dict(data) label = d[self.label_key] image = 
d[self.image_key] if self.image_key else None diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index dee3647258..67da1fde52 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -741,7 +741,7 @@ def __call__( return fg_indices, bg_indices -class ClassesToIndices(Transform): +class ClassesToIndices(TorchOrNumpyTransform): def __init__( self, num_classes: Optional[int] = None, @@ -768,10 +768,10 @@ def __init__( def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, output_shape: Optional[Sequence[int]] = None, - ) -> List[np.ndarray]: + ) -> List[DataObjects.Images]: """ Args: label: input data to compute the indices of every class. @@ -784,7 +784,10 @@ def __call__( output_shape = self.output_shape indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) if output_shape is not None: - indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + indices = [ + np.stack([np.unravel_index(i.cpu() if isinstance(i, torch.Tensor) else i, output_shape) for i in array]) + for array in indices + ] return indices diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9d38d37ca9..e9f397f7c7 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -283,11 +283,11 @@ def map_binary_to_indices( def map_classes_to_indices( - label: np.ndarray, + label: DataObjects.Images, num_classes: Optional[int] = None, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, -) -> List[np.ndarray]: +) -> List[DataObjects.Images]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -307,11 +307,11 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. 
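        For example, with an Argmax label (the expected values below mirror the
        updated tests for this helper)::

            label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])
            map_classes_to_indices(label, num_classes=3)
            # -> [array([0, 4, 8]), array([1, 5, 6]), array([2, 3, 7])]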
""" - img_flat: Optional[np.ndarray] = None + img_flat: Optional[DataObjects.Images] = None if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() + img_flat = (image > image_threshold).any(axis=0).ravel() # type: ignore - indices: List[np.ndarray] = [] + indices: List[DataObjects.Images] = [] # assuming the first dimension is channel channels = len(label) @@ -322,9 +322,10 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() - label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat - indices.append(np.nonzero(label_flat)[0]) + label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0).ravel() # type: ignore + label_flat = img_flat & label_flat if img_flat is not None else label_flat + idx = label_flat.nonzero() + indices.append(idx[0] if isinstance(idx, tuple) else idx.squeeze(1)) return indices @@ -468,7 +469,7 @@ def generate_label_classes_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, label_spatial_shape: Sequence[int], - indices: List[np.ndarray], + indices: List[DataObjects.Images], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, ) -> List[List[np.ndarray]]: @@ -497,8 +498,6 @@ def generate_label_classes_crop_centers( if any([i < 0 for i in ratios_]): raise ValueError("ratios should not contain negative number.") - # ensure indices are numpy array - indices = [np.asarray(i) for i in indices] for i, array in enumerate(indices): if len(array) == 0: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") @@ -510,7 +509,8 @@ def generate_label_classes_crop_centers( # randomly select the indices of a class based on the ratios indices_to_use = indices[i] random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx, *_ = convert_data_type(indices_to_use[random_int], np.ndarray) + center = np.unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py index 0ba3dd094a..a38f21a93d 100644 --- a/tests/test_classes_to_indices.py +++ b/tests/test_classes_to_indices.py @@ -12,66 +12,89 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ClassesToIndices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - # test Argmax data - {"num_classes": 3, "image_threshold": 0.0}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], + ] + ) -TEST_CASE_2 = [ - {"num_classes": 3, "image_threshold": 60}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + {"num_classes": 3, "image_threshold": 60}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + p(np.array([[[132, 1434, 
51], [61, 0, 133], [523, 44, 232]]])), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - {"image_threshold": 0.0}, - np.array( + TESTS.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + # test One-Hot data + {"image_threshold": 0.0}, + p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], ] - ), - None, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + ) -TEST_CASE_4 = [ - {"num_classes": None, "image_threshold": 60}, - np.array( + TESTS.append( [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + {"num_classes": None, "image_threshold": 60}, + p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], ] - ), - np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + ) -TEST_CASE_5 = [ - # test output_shape - {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - None, - [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], -] + TESTS.append( + [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + None, + [ + np.array([[0, 0], [1, 1], [2, 2]]), + np.array([[0, 1], [1, 2], [2, 0]]), + np.array([[0, 2], [1, 0], [2, 1]]), + ], + ] + ) class TestClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_args, label, image, expected_indices): indices = ClassesToIndices(**input_args)(label, image) for i, e in zip(indices, expected_indices): + i = i.cpu() if isinstance(i, torch.Tensor) else i np.testing.assert_allclose(i, e) diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2320954520..e27a7e540d 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -12,88 +12,149 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import map_classes_to_indices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - # test Argmax data - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "num_classes": 3, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + 
"label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [ + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - { - "label": np.array( + TESTS.append( + [ + # test One-Hot data + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) -TEST_CASE_4 = [ - { - "label": np.array( + TESTS.append( + [ + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "num_classes": None, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "num_classes": None, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_5 = [ - # test empty class - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 5, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) -TEST_CASE_6 = [ - # test empty class - { - "label": np.array( + TESTS.append( + [ + # test empty class + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) class TestMapClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = map_classes_to_indices(**input_data) for i, e in zip(indices, expected_indices): + i = i.cpu() if isinstance(i, torch.Tensor) else i np.testing.assert_allclose(i, e) diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index b21f971042..d562a44a6d 100644 
--- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -15,68 +15,77 @@ from parameterized import parameterized from monai.transforms import ClassesToIndices, RandCropByLabelClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ +TESTS_INDICES, TESTS_SHAPE = [], [] +for p in TEST_NDARRAYS: # One-Hot label - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] + TESTS_INDICES.append( + [ + { + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] + TESTS_INDICES.append( + [ + # Argmax label + { + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))}, + list, + (3, 2, 2, 2), + ] + ) -TEST_CASE_2 = [ - # provide label at runtime - { - "label": None, - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) class TestRandCropByLabelClasses(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS_INDICES + TESTS_SHAPE) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClasses(**input_param)(**input_data) self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0].shape, expected_shape) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS_INDICES) def test_indices(self, input_param, input_data, expected_type, expected_shape): input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) result = RandCropByLabelClasses(**input_param)(**input_data) diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 829096953b..27fe3425dd 100644 
--- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -15,52 +15,59 @@ from parameterized import parameterized from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - # One-Hot label - { - "keys": "img", - "label_key": "label", - "num_classes": None, - "spatial_size": [2, 2, -1], - "ratios": [1, 1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 3), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 2, 2, 3), + ] + ) -TEST_CASE_1 = [ - # Argmax label - { - "keys": "img", - "label_key": "label", - "num_classes": 2, - "spatial_size": [2, 2, 2], - "ratios": [1, 1], - "num_samples": 2, - "image_key": "image", - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 2, 2, 2), + ] + ) class TestRandCropByLabelClassesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_param, input_data, expected_type, expected_shape): result = RandCropByLabelClassesd(**input_param)(input_data) self.assertIsInstance(result, expected_type) From 33e081802e07b73c945e4108377e770e8145c8c1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 14 Jul 2021 17:42:13 +0100 Subject: [PATCH 103/176] post merge fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_divisible_pad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 4ccb31b25b..923370caca 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -24,6 +24,7 @@ # pad first dim to be divisible by 7, the second unchanged. 
TESTS.append( [ + {"k": (7, -1), "mode": "constant"}, p(np.zeros((3, 8, 7))), p(np.zeros((3, 14, 7))), ] From 7b84d43b81faa768f92ba7a1253771d99774e2f8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 15 Jul 2021 10:05:53 +0100 Subject: [PATCH 104/176] can't pickle torch.Generator Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 18 +++++++-------- monai/transforms/transform.py | 3 --- tests/test_rand_rician_noise.py | 33 ++++++++-------------------- tests/test_rand_rician_noised.py | 34 +++++++++-------------------- 4 files changed, 28 insertions(+), 60 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d7245760a2..ec0d1cc0da 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -143,17 +143,17 @@ def __init__( self._noise2: DataObjects.Images def _add_noise(self, img: DataObjects.Images, mean: float, std: float): + dtype_np = dtype_convert(img.dtype, np.ndarray) im_shape = img.shape + _std = self.R.uniform(0, std) if self.sample_std else std + self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) + self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np) if isinstance(img, torch.Tensor): - _std = float(torch.rand(1, generator=self.R_torch)) * std if self.sample_std else std - self._noise1 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype).to(img.device) - self._noise2 = torch.normal(mean, _std, size=im_shape, generator=self.R_torch).to(img.dtype).to(img.device) - return torch.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) - else: - _std = self.R.uniform(0, std) if self.sample_std else std - self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) - self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(img.dtype) - return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) # type: ignore + n1 = torch.tensor(self._noise1, device=img.device) + n2 = torch.tensor(self._noise2, device=img.device) + return torch.sqrt((img + n1) ** 2 + n2 ** 2) + + return np.sqrt((img + self._noise1) ** 2 + self._noise2 ** 2) # type: ignore def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3ff12b7522..76429d40b5 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -193,7 +193,6 @@ class Randomizable(ABC, ThreadUnsafe): """ R: np.random.RandomState = np.random.RandomState() - R_torch: Optional[torch._C.Generator] = None def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -218,7 +217,6 @@ def set_random_state( _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed _seed = _seed % MAX_SEED self.R = np.random.RandomState(_seed) - self.R_torch = torch.Generator().manual_seed(int(_seed)) return self if state is not None: @@ -228,7 +226,6 @@ def set_random_state( return self self.R = np.random.RandomState() - self.R_torch = None return self def randomize(self, data: Any) -> None: diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py index 5330c72006..7ec5fc4dc4 100644 --- a/tests/test_rand_rician_noise.py +++ b/tests/test_rand_rician_noise.py @@ -16,21 +16,21 @@ from parameterized import parameterized from monai.transforms import RandRicianNoise -from tests.utils import NumpyImageTestCase2D, 
TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TESTS = [ - ("test_zero_mean", 0, 0.1), - ("test_non_zero_mean", 1, 0.5), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(("test_zero_mean", p, 0, 0.1)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5)) class TestRandRicianNoise(NumpyImageTestCase2D): @parameterized.expand(TESTS) - def test_correct_results(self, _, mean, std): + def test_correct_results(self, _, in_type, mean, std): seed = 0 rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) rician_fn.set_random_state(seed) - noised = rician_fn(self.imt) + noised = rician_fn(in_type(self.imt)) np.random.seed(seed) np.random.random() _std = np.random.uniform(0, std) @@ -38,25 +38,10 @@ def test_correct_results(self, _, mean, std): (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2 + np.random.normal(mean, _std, size=self.imt.shape) ** 2 ) + if isinstance(noised, torch.Tensor): + noised = noised.cpu() np.testing.assert_allclose(expected, noised, atol=1e-5) -class TestRandRicianNoiseTorch(TorchImageTestCase2D): - @parameterized.expand(TESTS) - def test_correct_results(self, _, mean, std): - seed = 0 - for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: - rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std) - rician_fn.set_random_state(seed) - noised = rician_fn(self.imt.to(device)) - torch.manual_seed(seed) - _std = float(torch.rand(1)) * std - expected = torch.sqrt( - (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 - + torch.normal(mean, _std, size=self.imt.shape) ** 2 - ).to(device) - torch.testing.assert_allclose(expected, noised, rtol=1e-7, atol=1e-5) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_rician_noised.py b/tests/test_rand_rician_noised.py index cc5d1c4a84..666edc3d3b 100644 --- a/tests/test_rand_rician_noised.py +++ b/tests/test_rand_rician_noised.py @@ -16,22 +16,23 @@ from parameterized import parameterized from monai.transforms import RandRicianNoised -from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1] -TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) seed = 0 # Test with numpy class TestRandRicianNoisedNumpy(NumpyImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): + @parameterized.expand(TESTS) + def test_correct_results(self, _, in_type, keys, mean, std): rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std) rician_fn.set_random_state(seed) - noised = rician_fn({k: self.imt for k in keys}) + noised = rician_fn({k: in_type(self.imt) for k in keys}) np.random.seed(seed) for k in keys: np.random.random() @@ -40,23 +41,8 @@ def test_correct_results(self, _, keys, mean, std): (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2 + np.random.normal(mean, _std, size=self.imt.shape) ** 2 ) - np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) - - -# Test with torch -class TestRandRicianNoisedTorch(TorchImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): - rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, 
mean=mean, std=std) - rician_fn.set_random_state(seed) - noised = rician_fn({k: self.imt for k in keys}) - torch.manual_seed(seed) - for k in keys: - _std = float(torch.rand(1)) * std - expected = torch.sqrt( - (self.imt + torch.normal(mean, _std, size=self.imt.shape)) ** 2 - + torch.normal(mean, _std, size=self.imt.shape) ** 2 - ) + if isinstance(noised[k], torch.Tensor): + noised[k] = noised[k].cpu() np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) From b9d0a243616507e949dcddf0fc90e583c16b58e1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 15 Jul 2021 10:28:10 +0100 Subject: [PATCH 105/176] support pytorch==1.6 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e9f397f7c7..5230bbed8f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -322,7 +322,9 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0).ravel() # type: ignore + label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0) # type: ignore + # ravel with support for pytorch1.6 (flatten) + label_flat = label_flat.ravel() if hasattr(label_flat, "ravel") else label_flat.flatten() label_flat = img_flat & label_flat if img_flat is not None else label_flat idx = label_flat.nonzero() indices.append(idx[0] if isinstance(idx, tuple) else idx.squeeze(1)) From c687d5f5aaac54bf95de83d603afc24cf369a780 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 15 Jul 2021 15:41:43 +0100 Subject: [PATCH 106/176] spacing/spacingd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/nifti_writer.py | 6 +- monai/data/utils.py | 48 ++-- monai/transforms/__init__.py | 1 - monai/transforms/croppad/array.py | 3 +- monai/transforms/croppad/dictionary.py | 3 +- monai/transforms/intensity/array.py | 10 +- monai/transforms/io/array.py | 3 +- monai/transforms/spatial/array.py | 58 ++--- monai/transforms/spatial/dictionary.py | 6 +- monai/transforms/transform.py | 51 +--- monai/transforms/utility/array.py | 11 +- monai/transforms/utils.py | 3 +- monai/utils/__init__.py | 1 + monai/utils/misc.py | 59 +++++ tests/test_convert_data_type.py | 3 +- tests/test_dtype_convert.py | 2 +- tests/test_load_spacing_orientation.py | 3 +- tests/test_spacing.py | 311 ++++++++++++++----------- tests/test_spacingd.py | 143 ++++++------ tests/test_to_affine_nd.py | 37 +++ tests/test_zoom_affine.py | 126 ++++++---- 21 files changed, 515 insertions(+), 373 deletions(-) create mode 100644 tests/test_to_affine_nd.py diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 96076ba94c..9af4d396ef 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -17,9 +17,9 @@ from monai.config import DtypeLike from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform -from monai.transforms.transform import convert_data_type from monai.utils import GridSampleMode, GridSamplePadMode, optional_import from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type nib, _ = optional_import("nibabel") @@ -108,11 +108,11 @@ def write_nifti( sr = min(data_np.ndim, 3) if affine is None: affine = np.eye(4, 
dtype=np.float64) - affine = to_affine_nd(sr, affine) + affine = to_affine_nd(sr, affine) # type: ignore if target_affine is None: target_affine = affine - target_affine = to_affine_nd(sr, target_affine) + target_affine = to_affine_nd(sr, target_affine) # type: ignore if np.allclose(affine, target_affine, atol=1e-3): # no affine changes, save (data, affine) diff --git a/monai/data/utils.py b/monai/data/utils.py index 4c249db3d1..21c76c1cf1 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -39,7 +39,8 @@ issequenceiterable, optional_import, ) -from monai.utils.enums import Method +from monai.utils.enums import DataObjects, Method +from monai.utils.misc import convert_data_type pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -542,7 +543,7 @@ def rectify_header_sform_qform(img_nii): return img_nii -def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True): +def zoom_affine(affine: DataObjects.Images, scale: Sequence[float], diagonal: bool = True): """ To make column norm of `affine` the same as `scale`. If diagonal is False, returns an affine that combines orthogonal rotation and the new scale. @@ -567,8 +568,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru the updated `n x n` affine. """ - - affine = np.array(affine, dtype=float, copy=True) + affine, *_ = convert_data_type(deepcopy(affine), np.ndarray, dtype=float) if len(affine) != len(affine[0]): raise ValueError(f"affine must be n x n, got {len(affine)} x {len(affine[0])}.") scale_np = np.array(scale, dtype=float, copy=True) @@ -583,14 +583,15 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru scale_np[scale_np == 0] = 1.0 if diagonal: - return np.diag(np.append(scale_np, [1.0])) - rzs = affine[:-1, :-1] # rotation zoom scale - zs = np.linalg.cholesky(rzs.T @ rzs).T - rotation = rzs @ np.linalg.inv(zs) - s = np.sign(np.diag(zs)) * np.abs(scale_np) - # construct new affine with rotation and zoom - new_affine = np.eye(len(affine)) - new_affine[:-1, :-1] = rotation @ np.diag(s) + new_affine = np.diag(np.append(scale_np, [1.0])) + else: + rzs = affine[:-1, :-1] # rotation zoom scale + zs = np.linalg.cholesky(rzs.T @ rzs).T + rotation = rzs @ np.linalg.inv(zs) + s = np.sign(np.diag(zs)) * np.abs(scale_np) + # construct new affine with rotation and zoom + new_affine = np.eye(len(affine)) + new_affine[:-1, :-1] = rotation @ np.diag(s) return new_affine @@ -610,8 +611,8 @@ def compute_shape_offset( """ shape = np.array(spatial_shape, copy=True, dtype=float) sr = len(shape) - in_affine = to_affine_nd(sr, in_affine) - out_affine = to_affine_nd(sr, out_affine) + in_affine = to_affine_nd(sr, in_affine) # type: ignore + out_affine = to_affine_nd(sr, out_affine) # type: ignore in_coords = [(0.0, dim - 1.0) for dim in shape] corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) @@ -630,7 +631,7 @@ def compute_shape_offset( return out_shape.astype(int), offset -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: +def to_affine_nd(r: Union[DataObjects.Images, int], affine: DataObjects.Images) -> DataObjects.Images: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. 
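A minimal sketch of the array-agnostic behaviour targeted by this file's hunks
(values are illustrative; `zoom_affine` still computes and returns in numpy,
while `to_affine_nd` mirrors the input type)::

    import numpy as np
    import torch
    from monai.data.utils import to_affine_nd, zoom_affine

    affine = torch.diag(torch.tensor((2.0, 3.0, 1.0)))    # 3 x 3 (2D) affine
    zoom_affine(affine, scale=(1.0, 1.0), diagonal=True)  # -> np.eye(3)

    affine_2d = np.array([[2.0, 0.0, 4.0], [0.0, 3.0, 5.0], [0.0, 0.0, 1.0]])
    to_affine_nd(3, affine_2d)  # -> 4 x 4, keeping the 2D zoom and translation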
@@ -657,19 +658,20 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: an (r+1) x (r+1) matrix """ - affine_np = np.array(affine, dtype=np.float64) - if affine_np.ndim != 2: - raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + if affine.ndim != 2: + raise ValueError(f"affine must have 2 dimensions, got {affine.ndim}.") + device = affine.device if isinstance(affine, torch.Tensor) else None + new_affine, *_ = convert_data_type(deepcopy(r), type(affine), dtype=np.float64, device=device) if new_affine.ndim == 0: - sr: int = int(new_affine.astype(np.uint)) + sr: int = int(new_affine) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") new_affine = np.eye(sr + 1, dtype=np.float64) - d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) - new_affine[:d, :d] = affine_np[:d, :d] + new_affine, *_ = convert_data_type(new_affine, type(affine), device=device) + d = max(min(len(new_affine) - 1, len(affine) - 1), 1) + new_affine[:d, :d] = affine[:d, :d] if d > 1: - new_affine[:d, -1] = affine_np[:d, -1] + new_affine[:d, -1] = affine[:d, -1] return new_affine diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f25b3f0f0b..7ac019a913 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -311,7 +311,6 @@ ThreadUnsafe, Transform, apply_transform, - convert_data_type, ) from .utility.array import ( AddChannel, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a345e04121..fea46d19e3 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -23,7 +23,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, Transform, convert_data_type +from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, generate_label_classes_crop_centers, @@ -36,6 +36,7 @@ ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type __all__ = [ "Pad", diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index dbadbcdea9..56f78d1ac4 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -37,7 +37,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable, convert_data_type +from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( allow_missing_keys_mode, generate_label_classes_crop_centers, @@ -50,6 +50,7 @@ from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import DataObjects, InverseKeys +from monai.utils.misc import convert_data_type __all__ = [ "NumpyPadModeSequence", diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ec0d1cc0da..4a4e01345f 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -24,17 +24,11 @@ from monai.config import DtypeLike from monai.networks.layers import 
GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import ( - RandomizableTransform, - TorchOrNumpyTransform, - TorchTransform, - Transform, - convert_data_type, -) +from monai.transforms.transform import RandomizableTransform, TorchOrNumpyTransform, TorchTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, ensure_tuple, ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_convert +from monai.utils.misc import convert_data_type, dtype_convert __all__ = [ "RandGaussianNoise", diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 1a9ccfc914..453c709f35 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -23,11 +23,12 @@ from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver -from monai.transforms.transform import Transform, convert_data_type +from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, ensure_tuple, optional_import from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1cf3474837..e8e6c51eb6 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,6 +14,7 @@ """ import warnings +from copy import deepcopy from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -31,7 +32,6 @@ TorchOrNumpyTransform, TorchTransform, Transform, - convert_data_type, ) from monai.transforms.utils import ( create_control_grid, @@ -55,6 +55,7 @@ optional_import, ) from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type nib, _ = optional_import("nibabel") @@ -85,7 +86,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class Spacing(ToDoTransform): +class Spacing(TorchTransform): """ Resample input image into the specified `pixdim`. """ @@ -141,14 +142,14 @@ def __init__( def __call__( self, - data_array: np.ndarray, - affine: Optional[np.ndarray] = None, + data_array: DataObjects.Images, + affine: Optional[DataObjects.Images] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, output_spatial_shape: Optional[np.ndarray] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[DataObjects.Images, DataObjects.Images, DataObjects.Images]: """ Args: data_array: in shape (num_channels, H[, W, ...]). 
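[Editor's note — a hedged sketch of the `Spacing` API after this change, not part of the patch: the transform now accepts either container type and returns whatever it was given. Assumes this series applied; a `pixdim` matching the current spacing takes the no-resampling path shown in the next hunk.]

    import numpy as np
    import torch
    from monai.transforms import Spacing

    sp = Spacing(pixdim=(1.0, 1.0, 1.0))
    img = np.ones((1, 4, 4, 4), dtype=np.float32)
    out_np, _, _ = sp(img, affine=np.eye(4))  # numpy in -> numpy out
    out_t, _, _ = sp(torch.as_tensor(img))    # torch in -> torch out (affine defaults to identity)
    assert isinstance(out_np, np.ndarray) and isinstance(out_t, torch.Tensor)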
@@ -177,7 +178,10 @@ def __call__( """ _dtype = dtype or self.dtype or data_array.dtype - sr = data_array.ndim - 1 + data_array_torch: torch.Tensor + data_array_torch, orig_type, orig_device = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore + + sr = data_array_torch.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: @@ -186,6 +190,7 @@ def __call__( affine_ = np.eye(sr + 1, dtype=np.float64) else: affine_ = to_affine_nd(sr, affine) + affine_, *_ = convert_data_type(affine_, np.ndarray) out_d = self.pixdim[:sr] if out_d.size < sr: @@ -193,34 +198,35 @@ def __call__( # compute output affine, shape and offset new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) - output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) + output_shape, offset = compute_shape_offset(data_array_torch.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] - transform = np.linalg.inv(affine_) @ new_affine + affine_inv = np.linalg.inv(affine_) + transform = affine_inv @ new_affine # adapt to the actual rank transform = to_affine_nd(sr, transform) # no resampling if it's identity transform if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array.copy().astype(np.float32) + output_data, *_ = convert_data_type(deepcopy(data_array), dtype=_dtype) new_affine = to_affine_nd(affine, new_affine) - return output_data, affine, new_affine - # resample - affine_xform = AffineTransform( - normalized=False, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, - ) - output_data = affine_xform( - # AffineTransform requires a batch dim - torch.as_tensor(np.ascontiguousarray(data_array).astype(_dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)), - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ) - output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore - new_affine = to_affine_nd(affine, new_affine) + else: + # resample + affine_xform = AffineTransform( + normalized=False, + mode=mode or self.mode, + padding_mode=padding_mode or self.padding_mode, + align_corners=self.align_corners if align_corners is None else align_corners, + reverse_indexing=True, + ) + output_data = affine_xform( + # AffineTransform requires a batch dim + data_array_torch.unsqueeze(0), + convert_data_type(transform, torch.Tensor, data_array_torch.device, dtype=_dtype)[0], + spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, + ).squeeze(0) + output_data, *_ = convert_data_type(output_data, orig_type, dtype=np.float32) # type: ignore + new_affine = to_affine_nd(affine, new_affine) return output_data, affine, new_affine diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 46bcae6891..450bba0146 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -206,9 +206,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__( - self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] - ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: + def __call__(self, data: 
Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype, meta_key, meta_key_postfix in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.meta_keys, self.meta_key_postfix @@ -222,7 +220,7 @@ def __call__( # using affine fetched from d[affine_key] original_spatial_shape = d[key].shape[1:] d[key], old_affine, new_affine = self.spacing_transform( - data_array=np.asarray(d[key]), + data_array=d[key], affine=meta_data["affine"], mode=mode, padding_mode=padding_mode, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 76429d40b5..c2519d4165 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -21,10 +21,9 @@ from monai import transforms from monai.config import KeysCollection -from monai.config.type_definitions import DtypeLike from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_convert +from monai.utils.misc import convert_data_type __all__ = [ "ThreadUnsafe", @@ -33,60 +32,12 @@ "RandomizableTransform", "Transform", "MapTransform", - "convert_data_type", "NumpyTransform", ] ReturnType = TypeVar("ReturnType") -def convert_data_type( - data: DataObjects.Images, - output_type: Optional[type] = None, - device: Optional[torch.device] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, -) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: - """Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. - - Args: - data: data to be converted - output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) - device: if output is `torch.Tensor`, select device (if blank, unchanged) - dtype: dtype of output data. Converted to correct library type (e.g., - `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). - If left blank, it remains unchanged. 
- - Returns: - modified data, orig_type, orig_device - """ - orig_type = type(data) - orig_device = data.device if isinstance(data, torch.Tensor) else None - - output_type = output_type or orig_type - # objects like float don't have dtype, so return their type - dtype = dtype_convert(dtype or data.dtype, output_type) if hasattr(data, "dtype") else type(data) - - if output_type is torch.Tensor: - if orig_type is np.ndarray: - data = torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) - else: - data = torch.as_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) # type: ignore - elif output_type is np.ndarray: - if orig_type is torch.Tensor: - data = data.detach().cpu().numpy() # type: ignore - else: - data = np.array(data) - if dtype != data.dtype: - data = data.astype(dtype) # type: ignore - - if isinstance(data, torch.Tensor) and device is not None: - data = data.to(device) - - return data, orig_type, orig_device - - def _apply_transform( transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False ) -> ReturnType: diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 67da1fde52..11f134a624 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -23,14 +23,7 @@ import torch from monai.config import DtypeLike -from monai.transforms.transform import ( - NumpyTransform, - Randomizable, - TorchOrNumpyTransform, - TorchTransform, - Transform, - convert_data_type, -) +from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, TorchTransform, Transform from monai.transforms.utils import ( convert_to_numpy, convert_to_tensor, @@ -41,7 +34,7 @@ ) from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_convert +from monai.utils.misc import convert_data_type, dtype_convert PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 5230bbed8f..83199ee493 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike, IndexSelection from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose -from monai.transforms.transform import MapTransform, convert_data_type +from monai.transforms.transform import MapTransform from monai.utils import ( GridSampleMode, InterpolateMode, @@ -36,6 +36,7 @@ optional_import, ) from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type measure, _ = optional_import("skimage.measure", "0.14.2", min_version) cp, has_cp = optional_import("cupy") diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c7dd1972c2..f99db70deb 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -40,6 +40,7 @@ from .misc import ( MAX_SEED, ImageMetaKey, + convert_data_type, copy_to_device, dtype_convert, dtype_numpy_to_torch, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index c6ca5fbe44..2a5a8bc0b3 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -22,6 +22,8 @@ import numpy as np import torch +from monai.config.type_definitions import DtypeLike +from monai.utils.enums import DataObjects from monai.utils.module import get_torch_version_tuple __all__ = [ @@ -45,6 +47,7 @@ "MAX_SEED", "copy_to_device", "ImageMetaKey", + 
"convert_data_type", ] _seed = None @@ -345,6 +348,62 @@ def dtype_convert(dtype, data_type): return dtype_torch_to_numpy(dtype) +def convert_data_type( + data: Any, + output_type: Optional[type] = None, + device: Optional[torch.device] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, +) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: + """Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. + + Args: + data: data to be converted + output_type: `torch.Tensor` or `np.ndarray` (if blank, unchanged) + device: if output is `torch.Tensor`, select device (if blank, unchanged) + dtype: dtype of output data. Converted to correct library type (e.g., + `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`). + If left blank, it remains unchanged. + + Returns: + modified data, orig_type, orig_device + """ + orig_type = type(data) + orig_device = data.device if isinstance(data, torch.Tensor) else None + + output_type = output_type or orig_type + + def get_dtype(data: Any): + if hasattr(data, "dtype"): + return data.dtype + # need recursion + if isinstance(data, Sequence): + return get_dtype(data[0]) + # objects like float don't have dtype, so return their type + return type(data) + + dtype = dtype_convert(dtype or get_dtype(data), output_type) + + if output_type is torch.Tensor: + if orig_type is np.ndarray: + data = torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) + else: + data = torch.as_tensor(data) + if dtype != data.dtype: + data = data.to(dtype) # type: ignore + elif output_type is np.ndarray: + if orig_type is torch.Tensor: + data = data.detach().cpu().numpy() # type: ignore + else: + data = np.array(data) + if dtype != data.dtype: + data = data.astype(dtype) # type: ignore + + if isinstance(data, torch.Tensor) and device is not None: + data = data.to(device) + + return data, orig_type, orig_device + + def copy_to_device( obj: Any, device: Optional[Union[str, torch.device]], diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index 3a47e7d436..bf5179394a 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -16,8 +16,7 @@ import torch from parameterized import parameterized -from monai.transforms.transform import convert_data_type -from monai.utils.misc import dtype_convert +from monai.utils.misc import convert_data_type, dtype_convert TESTS: List[Tuple] = [] TESTS.append((np.array, torch.Tensor, np.float32, torch.float32)) diff --git a/tests/test_dtype_convert.py b/tests/test_dtype_convert.py index b07f58989e..90ae371a68 100644 --- a/tests/test_dtype_convert.py +++ b/tests/test_dtype_convert.py @@ -16,8 +16,8 @@ from parameterized import parameterized from monai.utils.misc import dtype_convert +from tests.utils import TEST_NDARRAYS -TEST_NDARRAYS = [torch.Tensor, np.ndarray] DTYPES = [torch.float32, np.float32, np.dtype(np.float32)] TESTS = [] diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 5e92e8dd37..0317b48cba 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -18,7 +18,8 @@ from nibabel.processing import resample_to_output from parameterized import parameterized -from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd, convert_data_type +from monai.transforms import AddChanneld, LoadImaged, Orientationd, Spacingd +from monai.utils.misc import convert_data_type FILES = tuple( 
os.path.join(os.path.dirname(__file__), "testing_data", filename) diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 6be6730c5a..5b7d054a7c 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -12,155 +12,204 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple +from tests.test_dtype_convert import TEST_NDARRAYS -TEST_CASES = [ - [ - {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, - np.arange(4).reshape((1, 2, 2)) + 1.0, # data - {"affine": np.eye(4)}, - np.array([[[1.0, 1.0], [3.0, 2.0]]]), - ], - [ - {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, - np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, - np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), - ], - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), - ], - [ - {"pixdim": (1.0, 1.0)}, - np.arange(24).reshape((2, 3, 4)), # data - {}, - np.array( - [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0)}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, - np.arange(24).reshape((1, 2, 3, 4)), # data - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - np.arange(24).reshape((1, 2, 3, 4)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - ], - [ - {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, - np.array( - [ +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, + np.arange(4).reshape((1, 2, 2)) 
+ 1.0, # data + {"affine": np.eye(4)}, + np.array([[[1.0, 1.0], [3.0, 2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.array([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]])}, + np.array([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {"affine": np.diag([-3.0, 0.2, 1.5, 1])}, + np.array([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.0, 1.0)}, + np.arange(24).reshape((2, 3, 4)), # data + {}, + np.array( + [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0)}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]])}, + np.arange(24).reshape((1, 2, 3, 4)), # data + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]])}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + np.arange(24).reshape((1, 2, 3, 4)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "nearest"}, + np.array( [ - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + [ + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], 
+ [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], - [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], - [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + [ + [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], + [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], + [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], + ] ] - ] - ), - ], - [ - {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, - np.arange(24).reshape((1, 4, 6)), # data - {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, - np.array( - [ + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": np.float32}, + np.arange(24).reshape((1, 4, 6)), # data + {"affine": np.array([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), "mode": "bilinear"}, + np.array( [ - [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], - [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], - [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + [ + [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], + [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], + [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], + ] ] - ] - ), - ], - [ - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, - np.ones((1, 2, 1, 2)), # data - {"affine": np.eye(4)}, - np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), - ], -] + ), + ] + ) + TESTS.append( + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2)), # data + {"affine": np.eye(4)}, + np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), + ] + ) class TestSpacingCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) - def test_spacing(self, init_param, img, data_param, expected_output): - res = Spacing(**init_param)(img, **data_param) - if not isinstance(res, tuple): - np.testing.assert_allclose(res, expected_output, atol=1e-6) - return - np.testing.assert_allclose(res[0], expected_output, atol=1e-6) - sr = len(res[0].shape) - 1 + @parameterized.expand(TESTS) + def test_spacing(self, in_type, init_param, img, data_param, expected_output): + _img = in_type(img) + output_data, _, new_affine = Spacing(**init_param)(_img, **data_param) + if isinstance(_img, torch.Tensor): + self.assertEqual(_img.device, output_data.device) + output_data = output_data.cpu() + + np.testing.assert_allclose(output_data, expected_output, atol=1e-3, rtol=1e-3) + sr = len(output_data.shape) - 1 if isinstance(init_param["pixdim"], float): init_param["pixdim"] = 
[init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(res[2]), axis=0))[:sr] + norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr] np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm) diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py index 61a4a4c38b..fd1ee7fd54 100644 --- a/tests/test_spacingd.py +++ b/tests/test_spacingd.py @@ -10,82 +10,95 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np +import torch +from parameterized import parameterized from monai.transforms import Spacingd +from tests.utils import TEST_NDARRAYS - -class TestSpacingDCase(unittest.TestCase): - def test_spacingd_3d(self): - data = {"image": np.ones((2, 10, 15, 20)), "image_meta_dict": {"affine": np.eye(4)}} - spacing = Spacingd(keys="image", pixdim=(1, 2, 1.4)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 8, 15)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag([1, 2, 1.4, 1.0])) - - def test_spacingd_2d(self): - data = {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_spacingd_2d_no_metadata(self): - data = {"image": np.ones((2, 10, 20))} - spacing = Spacingd(keys="image", pixdim=(1, 2)) - res = spacing(data) - self.assertEqual(("image", "image_meta_dict", "image_transforms"), tuple(sorted(res))) - np.testing.assert_allclose(res["image"].shape, (2, 10, 10)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 2, 1))) - - def test_interp_all(self): - data = { - "image": np.arange(20).reshape((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode="nearest", - pixdim=( - 1, - 0.2, - ), +TESTS: List[Tuple] = [] +for p in TEST_NDARRAYS: + TESTS.append( + ( + "spacing 3d", + {"image": p(np.ones((2, 10, 15, 20))), "image_meta_dict": {"affine": p(np.eye(4))}}, + dict(keys="image", pixdim=(1, 2, 1.4)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 8, 15), + np.diag([1, 2, 1.4, 1.0]), ) - res = spacing(data) - self.assertEqual( - ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + ) + TESTS.append( + ( + "spacing 2d", + {"image": np.ones((2, 10, 20)), "image_meta_dict": {"affine": np.eye(3)}}, + dict(keys="image", pixdim=(1, 2)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) - - def test_interp_sep(self): - data = { - "image": np.ones((2, 1, 10)), - "seg": np.ones((2, 1, 10)), - "image_meta_dict": {"affine": np.eye(4)}, - "seg_meta_dict": {"affine": np.eye(4)}, - } - spacing = Spacingd( - keys=("image", "seg"), - mode=("bilinear", "nearest"), - pixdim=( - 1, - 0.2, + ) + TESTS.append( + ( + "spacing 2d no metadata", + {"image": np.ones((2, 10, 
20))}, + dict(keys="image", pixdim=(1, 2)), + ("image", "image_meta_dict", "image_transforms"), + (2, 10, 10), + np.diag((1, 2, 1)), + ) + ) + TESTS.append( + ( + "interp all", + { + "image": np.arange(20).reshape((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + "image_meta_dict": {"affine": np.eye(4)}, + "seg_meta_dict": {"affine": np.eye(4)}, + }, + dict( + keys=("image", "seg"), + mode="nearest", + pixdim=( + 1, + 0.2, + ), ), + ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - res = spacing(data) - self.assertEqual( + ) + TESTS.append( + ( + "interp sep", + { + "image": np.ones((2, 1, 10)), + "seg": np.ones((2, 1, 10)), + "image_meta_dict": {"affine": np.eye(4)}, + "seg_meta_dict": {"affine": np.eye(4)}, + }, + dict(keys=("image", "seg"), mode=("bilinear", "nearest"), pixdim=(1, 0.2)), ("image", "image_meta_dict", "image_transforms", "seg", "seg_meta_dict", "seg_transforms"), - tuple(sorted(res)), + (2, 1, 46), + np.diag((1, 0.2, 1, 1)), ) - np.testing.assert_allclose(res["image"].shape, (2, 1, 46)) - np.testing.assert_allclose(res["image_meta_dict"]["affine"], np.diag((1, 0.2, 1, 1))) + ) + + +class TestSpacingDCase(unittest.TestCase): + @parameterized.expand(TESTS) + def test_spacingd(self, _, data, kw_args, expected_keys, expected_shape, expected_affine): + res = Spacingd(**kw_args)(data) + if isinstance(data["image"], torch.Tensor): + self.assertEqual(data["image"].device, res["image"].device) + self.assertEqual(expected_keys, tuple(sorted(res))) + np.testing.assert_allclose(res["image"].shape, expected_shape) + np.testing.assert_allclose(res["image_meta_dict"]["affine"], expected_affine) if __name__ == "__main__": diff --git a/tests/test_to_affine_nd.py b/tests/test_to_affine_nd.py new file mode 100644 index 0000000000..087f006ddd --- /dev/null +++ b/tests/test_to_affine_nd.py @@ -0,0 +1,37 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.utils import to_affine_nd +from tests.test_dtype_convert import TEST_NDARRAYS + +TESTS = [] +TESTS.append([2, np.eye(4)]) + + +class TestToAffinend(unittest.TestCase): + @parameterized.expand(TESTS) + def test_to_affine_nd(self, r, affine): + outs = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + res = to_affine_nd(p(r), q(affine)) + outs.append(res.cpu() if isinstance(res, torch.Tensor) else res) + np.testing.assert_allclose(outs[-1], outs[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 49c3c0dcac..7c39e0854f 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -10,69 +10,105 @@ # limitations under the License. 
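# [Editor's note, not part of the patch: like the other test rewrites in this
# commit, the cases below iterate over TEST_NDARRAYS from tests.utils so each
# affine is exercised as both a numpy array and a torch tensor. A hypothetical
# stand-in for that helper, for illustration only:
#     TEST_NDARRAYS = (np.array, torch.as_tensor)
# so that p(affine) produces the container type under test.]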
import unittest +from typing import List, Tuple import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data.utils import zoom_affine +from tests.utils import TEST_NDARRAYS -VALID_CASES = [ - ( - np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), - (10, 20, 30), - np.array([[8.94427191, -8.94427191, 0], [-4.47213595, -17.88854382, 0], [0.0, 0.0, 1.0]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20, 30), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), - ), - ( - np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10,), - np.array([[10, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), - ), - ( - [[1, 0, 10], [0, 1, 20], [0, 0, 1]] - @ ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[2, 0.3, 0], [0, 3, 0], [0, 0, 1]])), - (4, 5, 6), - ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[4, 0, 0], [0, 5, 0], [0, 0, 1]])), - ), -] +VALID_CASES: List[Tuple] = [] +DIAGONAL_CASES: List[Tuple] = [] +for p in TEST_NDARRAYS: + VALID_CASES.append( + ( + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10, 20, 30), + np.array([[8.94427191, -8.94427191, 0], [-4.47213595, -17.88854382, 0], [0.0, 0.0, 1.0]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10,), + np.array([[10, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]), + ) + ) + VALID_CASES.append( + ( + p, + [[1, 0, 10], [0, 1, 20], [0, 0, 1]] + @ ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[2, 0.3, 0], [0, 3, 0], [0, 0, 1]])), + (4, 5, 6), + ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[4, 0, 0], [0, 5, 0], [0, 0, 1]])), + ) + ) -DIAGONAL_CASES = [ - ( - np.array([[-1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), - (10, 20, 30), - np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), - ), - (np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), (10, 20, 30), np.array([[10, 0, 0], [0, 20, 0], [0.0, 0.0, 1.0]])), - ( # test default scale from affine - np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), - (10,), - np.array([[10, 0, 0], [0, 3.162278, 0], [0.0, 0.0, 1.0]]), - ), -] + DIAGONAL_CASES.append( + ( + p, + np.array([[-1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]), + ) + ) + DIAGONAL_CASES.append( + ( + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10, 20, 30), + np.array([[10, 0, 0], [0, 20, 0], [0.0, 0.0, 1.0]]), + ) + ) + DIAGONAL_CASES.append( + ( # test default scale from affine + p, + np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), + (10,), + np.array([[10, 0, 0], [0, 3.162278, 0], [0.0, 0.0, 1.0]]), + ) + ) class TestZoomAffine(unittest.TestCase): @parameterized.expand(VALID_CASES) - def test_correct(self, affine, scale, expected): - output = zoom_affine(affine, scale, diagonal=False) + def test_correct(self, in_type, affine, scale, expected): + _affine = 
in_type(affine) + output = zoom_affine(_affine, scale, diagonal=False) + if isinstance(_affine, torch.Tensor): + self.assertEqual(_affine.device, output.device) + output = output.cpu() ornt_affine = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(output)) ornt_output = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(affine)) np.testing.assert_array_equal(ornt_affine, ornt_output) np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6) @parameterized.expand(DIAGONAL_CASES) - def test_diagonal(self, affine, scale, expected): - output = zoom_affine(affine, scale, diagonal=True) + def test_diagonal(self, in_type, affine, scale, expected): + output = zoom_affine(in_type(affine), scale, diagonal=True) + if isinstance(output, torch.Tensor): + output = output.cpu().numpy() np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6) From f75bb5182ef6c64474cb70208ad06c326670fb24 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 15 Jul 2021 16:25:02 +0100 Subject: [PATCH 107/176] pytorch <= 1.5 support Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 83199ee493..a060fa4fc2 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -309,6 +309,12 @@ def map_classes_to_indices( """ img_flat: Optional[DataObjects.Images] = None + # if older pytorch, switch to numpy + if not hasattr(label, "ravel"): + label, *_ = convert_data_type(label, np.ndarray) + if image is not None: + image, *_ = convert_data_type(image, np.ndarray) + if image is not None: img_flat = (image > image_threshold).any(axis=0).ravel() # type: ignore @@ -323,9 +329,7 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0) # type: ignore - # ravel with support for pytorch1.6 (flatten) - label_flat = label_flat.ravel() if hasattr(label_flat, "ravel") else label_flat.flatten() + label_flat = (label[c : c + 1] if channels > 1 else label == c).any(axis=0).ravel() # type: ignore label_flat = img_flat & label_flat if img_flat is not None else label_flat idx = label_flat.nonzero() indices.append(idx[0] if isinstance(idx, tuple) else idx.squeeze(1)) From 79157fa39e3af86aee5ee40565bda95a97759955 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 15 Jul 2021 16:58:29 +0100 Subject: [PATCH 108/176] ignore Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e8e6c51eb6..6eef7742d7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -189,8 +189,8 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, affine) - affine_, *_ = convert_data_type(affine_, np.ndarray) + affine_ = to_affine_nd(sr, affine) # type: ignore + affine_, *_ = convert_data_type(affine_, np.ndarray) # type: ignore out_d = self.pixdim[:sr] if out_d.size < sr: @@ -293,7 +293,7 @@ def __call__( affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: - affine_ = to_affine_nd(sr, 
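# [Editor's note, not part of the patch: patch 111 removes `as_tensor_output`
# from AffineGrid/RandAffineGrid/Resample/Affine and the elastic transforms;
# outputs now follow the input's container type (and device, where relevant),
# which the rewritten tests verify by parameterizing over TEST_NDARRAYS and
# the available devices.]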
affine) + affine_ = to_affine_nd(sr, affine) # type: ignore src = nib.io_orientation(affine_) if self.as_closest_canonical: spatial_ornt = src From f659b97565ef2934ee5e5009f4f22994129a2bb6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 16 Jul 2021 09:36:15 +0100 Subject: [PATCH 109/176] fix unit tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/nifti_writer.py | 2 ++ tests/test_zoom_affine.py | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 9af4d396ef..17740a27bc 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -103,6 +103,8 @@ def write_nifti( # if torch, convert to numpy data_np: np.ndarray data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore + if target_affine is not None: + target_affine, *_ = convert_data_type(target_affine, np.ndarray) # type: ignore dtype = dtype or data_np.dtype sr = min(data_np.ndim, 3) diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index 7c39e0854f..ba28fd5b05 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -96,9 +96,6 @@ class TestZoomAffine(unittest.TestCase): def test_correct(self, in_type, affine, scale, expected): _affine = in_type(affine) output = zoom_affine(_affine, scale, diagonal=False) - if isinstance(_affine, torch.Tensor): - self.assertEqual(_affine.device, output.device) - output = output.cpu() ornt_affine = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(output)) ornt_output = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(affine)) np.testing.assert_array_equal(ornt_affine, ornt_output) @@ -107,8 +104,6 @@ def test_correct(self, in_type, affine, scale, expected): @parameterized.expand(DIAGONAL_CASES) def test_diagonal(self, in_type, affine, scale, expected): output = zoom_affine(in_type(affine), scale, diagonal=True) - if isinstance(output, torch.Tensor): - output = output.cpu().numpy() np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6) From 4447c1abfe4fdd75454505e2238bd5ee227ab19c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 16 Jul 2021 09:45:19 +0100 Subject: [PATCH 110/176] remove torch Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_zoom_affine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_zoom_affine.py b/tests/test_zoom_affine.py index ba28fd5b05..7d21341a01 100644 --- a/tests/test_zoom_affine.py +++ b/tests/test_zoom_affine.py @@ -14,7 +14,6 @@ import nibabel as nib import numpy as np -import torch from parameterized import parameterized from monai.data.utils import zoom_affine From 1ad0e4d5a72230e94df9720567003dcc57308d4b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 16 Jul 2021 16:46:36 +0100 Subject: [PATCH 111/176] resample, affine, affinegrid, elastic2d, elastic3d Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 102 +++---- monai/transforms/spatial/dictionary.py | 24 +- tests/test_affine.py | 182 +++++++++---- tests/test_affine_grid.py | 164 ++++++----- tests/test_affined.py | 195 ++++++++----- tests/test_inverse_collation.py | 2 - tests/test_rand_affine.py | 224 ++++++++------- tests/test_rand_affine_grid.py | 326 +++++++++++----------- tests/test_rand_affined.py | 362 +++++++++++++------------ 
tests/test_rand_elastic_2d.py | 153 ++++++----- tests/test_rand_elastic_3d.py | 132 +++++---- tests/test_rand_elasticd_2d.py | 250 +++++++++-------- tests/test_rand_elasticd_3d.py | 214 ++++++++------- tests/test_resampler.py | 161 +++++++---- 14 files changed, 1385 insertions(+), 1106 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 6eef7742d7..273281ce1b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -952,7 +952,7 @@ def __call__( return zoomer(img) -class AffineGrid(ToDoTransform): +class AffineGrid(TorchTransform): """ Affine transforms on the coordinates. @@ -987,7 +987,6 @@ def __init__( shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, device: Optional[torch.device] = None, affine: Optional[DataObjects.Images] = None, ) -> None: @@ -995,10 +994,7 @@ def __init__( self.shear_params = shear_params self.translate_params = translate_params self.scale_params = scale_params - - self.as_tensor_output = as_tensor_output self.device = device - self.affine = affine def __call__( @@ -1021,7 +1017,7 @@ def __call__( else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") - affine: Union[torch.Tensor, np.ndarray] + affine: DataObjects.Images if self.affine is None: spatial_dims = len(grid.shape) - 1 affine = np.eye(spatial_dims + 1) @@ -1036,17 +1032,14 @@ def __call__( else: affine = self.affine - if isinstance(affine, np.ndarray): - affine = torch.as_tensor(np.ascontiguousarray(affine)) + grid, orig_type, orig_device = convert_data_type(grid, torch.Tensor, dtype=float, device=self.device) + affine, *_ = convert_data_type(affine, torch.Tensor, dtype=float, device=grid.device) # type: ignore - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - affine = affine.to(self.device) - grid = grid.to(self.device) - grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) + grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine + grid, *_ = convert_data_type(grid, orig_type, orig_device) + return grid, affine class RandAffineGrid(Randomizable, ToDoTransform): @@ -1060,7 +1053,6 @@ def __init__( shear_range: RandRange = None, translate_range: RandRange = None, scale_range: RandRange = None, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1075,8 +1067,6 @@ def __init__( translate_range: translate_range with format matching `rotate_range`. scale_range: scaling_range with format matching `rotate_range`. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1). - as_tensor_output: whether to output tensor instead of numpy array. - defaults to True. device: device to store the output grid data. 
See also: @@ -1095,7 +1085,6 @@ def __init__( self.translate_params: Optional[List[float]] = None self.scale_params: Optional[List[float]] = None - self.as_tensor_output = as_tensor_output self.device = device self.affine: Optional[DataObjects.Images] = None @@ -1135,7 +1124,6 @@ def __call__( shear_params=self.shear_params, translate_params=self.translate_params, scale_params=self.scale_params, - as_tensor_output=self.as_tensor_output, device=self.device, ) grid, self.affine = affine_grid(spatial_size, grid) @@ -1146,7 +1134,7 @@ def get_transformation_matrix(self) -> Optional[DataObjects.Images]: return self.affine -class RandDeformGrid(Randomizable, ToDoTransform): +class RandDeformGrid(Randomizable, Transform): """ Generate random deformation grid. """ @@ -1196,12 +1184,11 @@ def __call__(self, spatial_size: Sequence[int]): return control_grid -class Resample(ToDoTransform): +class Resample(TorchTransform): def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1215,12 +1202,10 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: whether to return a torch tensor. Defaults to False. device: device on which the tensor will be allocated. """ self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( @@ -1241,18 +1226,15 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. 
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ - - if not isinstance(img, torch.Tensor): - img = torch.as_tensor(np.ascontiguousarray(img)) if grid is None: raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") - grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() - if self.device: - img = img.to(self.device) - grid = grid.to(self.device) + + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, device=self.device, dtype=float) # type: ignore + grid, *_ = convert_data_type(deepcopy(grid), torch.Tensor, device=img_t.device, dtype=float) if USE_COMPILED: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] += (dim - 1.0) / 2.0 grid = grid[:-1] / grid[-1:] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) @@ -1265,32 +1247,32 @@ def __call__( bound = 1 _interp_mode = self.mode.value if mode is None else GridSampleMode(mode).value out = grid_pull( - img.unsqueeze(0).float(), + img_t.unsqueeze(0).float(), grid.unsqueeze(0).float(), bound=bound, extrapolate=True, interpolation=1 if _interp_mode == "bilinear" else _interp_mode, )[0] else: - for i, dim in enumerate(img.shape[1:]): + for i, dim in enumerate(img_t.shape[1:]): grid[i] = 2.0 * grid[i] / (dim - 1.0) grid = grid[:-1] / grid[-1:] - index_ordering: List[int] = list(range(img.ndimension() - 2, -1, -1)) + index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) grid = grid[index_ordering] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) out = torch.nn.functional.grid_sample( - img.unsqueeze(0).float(), + img_t.unsqueeze(0).float(), grid.unsqueeze(0).float(), mode=self.mode.value if mode is None else GridSampleMode(mode).value, padding_mode=self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value, align_corners=True, )[0] - if self.as_tensor_output: - return torch.as_tensor(out) - return np.asarray(out.cpu().numpy()) + + out, *_ = convert_data_type(out, orig_type, orig_device) + return out # type: ignore -class Affine(Transform): +class Affine(TorchTransform): """ Transform ``img`` given the affine parameters. """ @@ -1304,7 +1286,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, image_only: bool = False, ) -> None: @@ -1330,8 +1311,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). 
""" @@ -1340,11 +1319,10 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, - as_tensor_output=True, device=device, ) self.image_only = image_only - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) @@ -1378,7 +1356,7 @@ def __call__( return ret if self.image_only else (ret, affine) -class RandAffine(RandomizableTransform): +class RandAffine(RandomizableTransform, TorchTransform): """ Random affine transform. """ @@ -1394,7 +1372,6 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, ) -> None: """ @@ -1426,8 +1403,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1441,10 +1416,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.cache_grid = cache_grid @@ -1527,8 +1501,8 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) if not do_resampling: - img = img.float() if isinstance(img, torch.Tensor) else img.astype("float32") - return torch.Tensor(img) if self.resampler.as_tensor_output else np.array(img) + img, *_ = convert_data_type(img, dtype=np.float32) + return img grid = self.get_identity_grid(sp_size) if self._do_transform: grid = self.rand_affine_grid(grid=grid) @@ -1537,7 +1511,7 @@ def __call__( ) -class Rand2DElastic(RandomizableTransform): +class Rand2DElastic(TorchTransform, RandomizableTransform): """ Random elastic deformation and affine in 2D """ @@ -1554,7 +1528,6 @@ def __init__( spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1586,8 +1559,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. 
See also: @@ -1603,10 +1574,9 @@ def __init__( shear_range=shear_range, translate_range=translate_range, scale_range=scale_range, - as_tensor_output=True, device=device, ) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.resampler = Resample(device=device) self.spatial_size = spatial_size self.mode: GridSampleMode = GridSampleMode(mode) @@ -1663,7 +1633,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class Rand3DElastic(RandomizableTransform): +class Rand3DElastic(TorchTransform, RandomizableTransform): """ Random elastic deformation and affine in 3D """ @@ -1680,7 +1650,6 @@ def __init__( spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, ) -> None: """ @@ -1714,8 +1683,6 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. See also: @@ -1723,8 +1690,8 @@ def __init__( - :py:class:`Affine` for the affine transformation parameters configurations. """ RandomizableTransform.__init__(self, prob) - self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, True, device) - self.resampler = Resample(as_tensor_output=as_tensor_output, device=device) + self.rand_affine_grid = RandAffineGrid(rotate_range, shear_range, translate_range, scale_range, device) + self.resampler = Resample(device=device) self.sigma_range = sigma_range self.magnitude_range = magnitude_range @@ -1778,9 +1745,10 @@ def __call__( if self._do_transform: if self.rand_offset is None: raise AssertionError - grid = torch.as_tensor(np.ascontiguousarray(grid), device=self.device) + grid, *_ = convert_data_type(grid, torch.Tensor, device=self.device) + offset, *_ = convert_data_type(self.rand_offset, torch.Tensor, device=self.device) + offset = offset.unsqueeze(0) # type: ignore gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) - offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid = self.rand_affine_grid(grid=grid) return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 450bba0146..a39b0526bb 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -571,7 +571,6 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -598,8 +597,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. 
- as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -614,7 +611,6 @@ def __init__( translate_params=translate_params, scale_params=scale_params, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -681,7 +677,6 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, cache_grid: bool = False, - as_tensor_output: bool = True, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -717,8 +712,6 @@ def __init__( cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -736,7 +729,6 @@ def __init__( scale_range=scale_range, spatial_size=spatial_size, cache_grid=cache_grid, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -767,7 +759,7 @@ def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Imag if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid) + grid = self.rand_affine.rand_affine_grid(grid=grid).clone() # type: ignore affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() # type: ignore[assignment] for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -783,12 +775,6 @@ def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Imag # do the transform if do_resampling: d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - # if not doing transform and and spatial size is unchanged, only need to do numpy/torch conversion - else: - if self.rand_affine.resampler.as_tensor_output and not isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]) - elif not self.rand_affine.resampler.as_tensor_output and isinstance(d[key], torch.Tensor): - d[key] = d[key].detach().cpu().numpy() # type: ignore[union-attr] return d @@ -841,7 +827,6 @@ def __init__( scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -878,8 +863,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. 
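The ``cache_grid`` behaviour documented above is worth a concrete illustration: with a static ``spatial_size`` the identity sampling grid can be built once and reused, so only the random affine matrix changes per call. A minimal sketch with arbitrary parameter values:

    import numpy as np
    import torch
    from monai.transforms import RandAffined

    # caching only works when spatial_size is fixed; for dynamic sizes a fresh
    # identity grid is needed per sample anyway (the transform warns in that case,
    # see test_ill_cache below)
    xform = RandAffined(
        keys=("img", "seg"),
        prob=0.9,
        rotate_range=(np.pi / 2,),
        spatial_size=(2, 2),
        cache_grid=True,
    )
    out = xform({"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))})  # each value -> (3, 2, 2)

The ``.clone()`` added to ``RandAffined.__call__`` above fits the same picture: presumably it detaches the deformed grid from the shared cached identity grid so that downstream in-place operations cannot corrupt the cache.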
@@ -898,7 +881,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) @@ -958,7 +940,6 @@ def __init__( scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, - as_tensor_output: bool = False, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -996,8 +977,6 @@ def __init__( Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. - as_tensor_output: the computation is implemented using pytorch tensors, this option specifies - whether to convert it back to numpy arrays. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1016,7 +995,6 @@ def __init__( translate_range=translate_range, scale_range=scale_range, spatial_size=spatial_size, - as_tensor_output=as_tensor_output, device=device, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) diff --git a/tests/test_affine.py b/tests/test_affine.py index dd82d72e23..b33fce8104 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -16,77 +16,143 @@ from parameterized import parameterized from monai.transforms import Affine +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True), - {"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2)), "spatial_size": (4, 4)}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (-1, 0, 0)}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + 
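+                # each case: (constructor kwargs, call kwargs, expected output);
+                # p comes from TEST_NDARRAYS and wraps the data as numpy or torch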
dict(padding_mode="zeros", device=device), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict(rotate_params=[np.pi / 2], padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2)), "spatial_size": (4, 4, 4)}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=device, image_only=True), + {"img": p(np.arange(9).reshape((1, 3, 3))), "spatial_size": (-1, 0)}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (-1, 0, 0)}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(rotate_params=[np.pi / 2], padding_mode="zeros", device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "spatial_size": (4, 4, 4)}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) result = g(**input_data) if isinstance(result, tuple): result = result[0] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, 
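+                # tensors must come back on the same device they went in on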
expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 24772b9a21..23e12ba6f0 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -16,88 +16,110 @@ from parameterized import parameterized from monai.transforms import AffineGrid +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (2, 2)}, - np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [ - {"as_tensor_output": True, "device": None}, - {"spatial_size": (2, 2)}, - torch.tensor([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), - ], - [{"as_tensor_output": False, "device": None}, {"grid": np.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [{"as_tensor_output": True, "device": torch.device("cpu:0")}, {"grid": np.ones((3, 3, 3))}, torch.ones((3, 3, 3))], - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"as_tensor_output": True, "device": torch.device("cpu:0")}, - {"grid": torch.ones((3, 3, 3))}, - torch.ones((3, 3, 3)), - ], - [ - { - "rotate_params": (1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((3, 3, 3))}, - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208], [-19.2208, -19.2208, -19.2208]], - [[-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264], [-11.4264, -11.4264, -11.4264]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + {"device": device}, + {"spatial_size": (2, 2)}, + np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]), ] - ), - ], - [ - { - "rotate_params": (1.0, 1.0, 1.0), - "scale_params": (-20, 10), - "as_tensor_output": True, - "device": torch.device("cpu:0"), - }, - {"grid": torch.ones((4, 3, 3, 3))}, - torch.tensor( + ) + + TESTS.append([{"device": device}, {"grid": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( + [ + { + "rotate_params": (1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) + TESTS.append( [ - [ - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], - ], - [ - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - [[-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381], [-20.2381, -20.2381, -20.2381]], - ], - [ - [[-0.5844, -0.5844, -0.5844], 
[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], - ], - [ - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], - ], + { + "rotate_params": (1.0, 1.0, 1.0), + "scale_params": (-20, 10), + "device": device, + }, + {"grid": p(torch.ones((4, 3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]], + ], + [ + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + [ + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + [-20.2381, -20.2381, -20.2381], + ], + ], + [ + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]], + ], + [ + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]], + ], + ] + ) + ), ] - ), - ], -] + ) class TestAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) result, _ = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "grid" in input_data: + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_affined.py b/tests/test_affined.py index 850f12905d..5bbb384537 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -16,84 +16,145 @@ from parameterized import parameterized from monai.transforms import Affined +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, spatial_size=(-1, 0), device=None), - {"img": np.arange(9).reshape((1, 3, 3))}, - np.arange(9).reshape(1, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape(1, 2, 2), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), as_tensor_output=False, device=None), - {"img": np.arange(4).reshape((1, 2, 2))}, - 
np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), as_tensor_output=False, device=None), - {"img": np.arange(27).reshape((1, 3, 3, 3))}, - np.arange(27).reshape(1, 3, 3, 3), - ], - [ - dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), as_tensor_output=False, device=None), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0), device=device), + {"img": p(np.arange(9).reshape((1, 3, 3)))}, + p(np.arange(9).reshape(1, 3, 3)), ] - ), - ], - [ - dict( - keys="img", - rotate_params=[np.pi / 2], - padding_mode="zeros", - spatial_size=(4, 4, 4), - as_tensor_output=False, - device=None, - ), - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 6.0, 4.0, 0.0], [0.0, 7.0, 5.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(keys="img", padding_mode="zeros", device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape(1, 2, 2)), ] - ), - ], -] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4), device=device), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device), + {"img": p(np.arange(27).reshape((1, 3, 3, 3)))}, + p(np.arange(27).reshape(1, 3, 3, 3)), + ] + ) + TESTS.append( + [ + dict(keys="img", padding_mode="zeros", spatial_size=(4, 4, 4), device=device), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 
0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict( + keys="img", + rotate_params=[np.pi / 2], + padding_mode="zeros", + spatial_size=(4, 4, 4), + device=device, + ), + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + p( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 3.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 6.0, 4.0, 0.0], + [0.0, 7.0, 5.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) class TestAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) result = g(input_data)["img"] - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index 816f7eb1b1..334ea5f4b3 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -56,7 +56,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] @@ -75,7 +74,6 @@ prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - as_tensor_output=False, ), ] ] diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 1e1a23bc09..1a0502df3f 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -16,114 +16,136 @@ from parameterized import parameterized from monai.transforms import RandAffine +from monai.utils.misc import convert_data_type +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=-1), - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None), - {"img": torch.arange(27).reshape((3, 3, 3)), "spatial_size": (2, 2)}, - np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]]), - ], - [ - dict(as_tensor_output=True, device=None), - {"img": torch.ones((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), cache_grid=True), - {"img": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - padding_mode="zeros", - spatial_size=(2, 2, 2), - cache_grid=True, - 
device=None, - ), - {"img": torch.ones((1, 3, 3, 3)), "mode": "bilinear"}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "spatial_size": (3, 3)}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - spatial_size=(3, 3), - cache_grid=True, - as_tensor_output=True, - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=-1), + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])), + ] + ) + TESTS.append( + [ + dict(device=device), + {"img": p(torch.ones((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), cache_grid=True), + {"img": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + padding_mode="zeros", + spatial_size=(2, 2, 2), + cache_grid=True, + device=device, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "mode": "bilinear"}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "spatial_size": (3, 3)}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) -ARR_NUMPY = np.arange(9 * 10).reshape(1, 9, 10) -ARR_TORCH = torch.Tensor(ARR_NUMPY) TEST_CASES_SKIPPED_CONSISTENCY = [] -for im in (ARR_NUMPY, ARR_TORCH): - for as_tensor_output in (True, 
False): - for in_dtype_is_int in (True, False): - TEST_CASES_SKIPPED_CONSISTENCY.append((im, as_tensor_output, in_dtype_is_int)) +for p in TEST_NDARRAYS: + for in_dtype in (np.int32, np.float32): + TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype)) class TestRandAffine(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) result = g(**input_data) if input_param.get("cache_grid", False): self.assertTrue(g._cached_grid is not None) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): @@ -132,15 +154,11 @@ def test_ill_cache(self): RandAffine(cache_grid=True, spatial_size=(1, 1, -1)) @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY) - def test_skipped_transform_consistency(self, im, as_tensor_output, in_dtype_is_int): - t1 = RandAffine(prob=0, as_tensor_output=as_tensor_output) - t2 = RandAffine(prob=1, spatial_size=(10, 11), as_tensor_output=as_tensor_output) + def test_skipped_transform_consistency(self, im, in_dtype): + t1 = RandAffine(prob=0) + t2 = RandAffine(prob=1, spatial_size=(10, 11)) - # change dtype to int32 or float32 - if in_dtype_is_int: - im = im.astype("int32") if isinstance(im, np.ndarray) else im.int() - else: - im = im.astype("float32") if isinstance(im, np.ndarray) else im.float() + im, *_ = convert_data_type(im, dtype=in_dtype) out1 = t1(im) out2 = t2(im) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..3b81a924ba 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -16,182 +16,196 @@ from parameterized import parameterized from monai.transforms import RandAffineGrid +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [{"as_tensor_output": False, "device": None}, {"grid": torch.ones((3, 3, 3))}, np.ones((3, 3, 3))], - [ - {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, - {"grid": torch.arange(0, 27).reshape((3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [-32.81998, -33.910976, -35.001972], - [-36.092968, -37.183964, -38.27496], - [-39.36596, -40.456955, -41.54795], - ], - [[2.1380205, 3.1015975, 4.0651755], [5.028752, 5.9923296, 6.955907], [7.919484, 8.883063, 9.84664]], - [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], - ] - ) - ), - ], - [ - {"translate_range": (3, 3, 3), "as_tensor_output": False, "device": torch.device("cpu:0")}, - {"spatial_size": (3, 3, 3)}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append([{"device": device}, {"grid": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))]) + TESTS.append( [ - [ - [ - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - [0.17881513, 0.17881513, 0.17881513], - ], - [ - [1.1788151, 1.1788151, 1.1788151], - [1.1788151, 1.1788151, 1.1788151], - 
[1.1788151, 1.1788151, 1.1788151], - ], - [ - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - [2.1788151, 2.1788151, 2.1788151], - ], - ], - [ - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - [ - [-2.283164, -2.283164, -2.283164], - [-1.283164, -1.283164, -1.283164], - [-0.28316402, -0.28316402, -0.28316402], - ], - ], - [ - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - [ - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - [-2.6388912, -1.6388912, -0.6388912], - ], - ], - [ - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], - ], - ] - ), - ], - [ - {"rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, - {"grid": torch.arange(0, 108).reshape((4, 3, 3, 3))}, - torch.tensor( - np.array( - [ - [ - [ - [-9.4201e00, -8.1672e00, -6.9143e00], - [-5.6614e00, -4.4085e00, -3.1556e00], - [-1.9027e00, -6.4980e-01, 6.0310e-01], - ], - [ - [1.8560e00, 3.1089e00, 4.3618e00], - [5.6147e00, 6.8676e00, 8.1205e00], - [9.3734e00, 1.0626e01, 1.1879e01], - ], + {"rotate_range": (1, 2), "translate_range": (3, 3, 3)}, + {"grid": p(torch.arange(0, 27).reshape((3, 3, 3)))}, + p( + np.array( [ - [1.3132e01, 1.4385e01, 1.5638e01], - [1.6891e01, 1.8144e01, 1.9397e01], - [2.0650e01, 2.1902e01, 2.3155e01], - ], - ], - [ - [ - [9.9383e-02, -4.8845e-01, -1.0763e00], - [-1.6641e00, -2.2519e00, -2.8398e00], - [-3.4276e00, -4.0154e00, -4.6032e00], - ], - [ - [-5.1911e00, -5.7789e00, -6.3667e00], - [-6.9546e00, -7.5424e00, -8.1302e00], - [-8.7180e00, -9.3059e00, -9.8937e00], - ], - [ - [-1.0482e01, -1.1069e01, -1.1657e01], - [-1.2245e01, -1.2833e01, -1.3421e01], - [-1.4009e01, -1.4596e01, -1.5184e01], - ], - ], + [ + [-32.81998, -33.910976, -35.001972], + [-36.092968, -37.183964, -38.27496], + [-39.36596, -40.456955, -41.54795], + ], + [ + [2.1380205, 3.1015975, 4.0651755], + [5.028752, 5.9923296, 6.955907], + [7.919484, 8.883063, 9.84664], + ], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0], [24.0, 25.0, 26.0]], + ] + ) + ), + ] + ) + TESTS.append( + [ + {"translate_range": (3, 3, 3), "device": device}, + {"spatial_size": (3, 3, 3)}, + np.array( [ [ - [5.9635e01, 6.1199e01, 6.2764e01], - [6.4328e01, 6.5892e01, 6.7456e01], - [6.9021e01, 7.0585e01, 7.2149e01], - ], - [ - [7.3714e01, 7.5278e01, 7.6842e01], - [7.8407e01, 7.9971e01, 8.1535e01], - [8.3099e01, 8.4664e01, 8.6228e01], + [ + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + [0.17881513, 0.17881513, 0.17881513], + ], + [ + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + [1.1788151, 1.1788151, 1.1788151], + ], + [ + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + [2.1788151, 2.1788151, 2.1788151], + ], ], [ - [8.7792e01, 8.9357e01, 9.0921e01], - [9.2485e01, 9.4049e01, 9.5614e01], - [9.7178e01, 9.8742e01, 1.0031e02], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, 
-1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], + [ + [-2.283164, -2.283164, -2.283164], + [-1.283164, -1.283164, -1.283164], + [-0.28316402, -0.28316402, -0.28316402], + ], ], - ], - [ [ - [8.1000e01, 8.2000e01, 8.3000e01], - [8.4000e01, 8.5000e01, 8.6000e01], - [8.7000e01, 8.8000e01, 8.9000e01], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], + [ + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + [-2.6388912, -1.6388912, -0.6388912], + ], ], [ - [9.0000e01, 9.1000e01, 9.2000e01], - [9.3000e01, 9.4000e01, 9.5000e01], - [9.6000e01, 9.7000e01, 9.8000e01], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], ], + ] + ), + ] + ) + TESTS.append( + [ + {"device": device, "rotate_range": (1.0, 1.0, 1.0), "shear_range": (0.1,), "scale_range": (1.2,)}, + {"grid": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))}, + p( + np.array( [ - [9.9000e01, 1.0000e02, 1.0100e02], - [1.0200e02, 1.0300e02, 1.0400e02], - [1.0500e02, 1.0600e02, 1.0700e02], - ], - ], - ] - ) - ), - ], -] + [ + [ + [-9.4201e00, -8.1672e00, -6.9143e00], + [-5.6614e00, -4.4085e00, -3.1556e00], + [-1.9027e00, -6.4980e-01, 6.0310e-01], + ], + [ + [1.8560e00, 3.1089e00, 4.3618e00], + [5.6147e00, 6.8676e00, 8.1205e00], + [9.3734e00, 1.0626e01, 1.1879e01], + ], + [ + [1.3132e01, 1.4385e01, 1.5638e01], + [1.6891e01, 1.8144e01, 1.9397e01], + [2.0650e01, 2.1902e01, 2.3155e01], + ], + ], + [ + [ + [9.9383e-02, -4.8845e-01, -1.0763e00], + [-1.6641e00, -2.2519e00, -2.8398e00], + [-3.4276e00, -4.0154e00, -4.6032e00], + ], + [ + [-5.1911e00, -5.7789e00, -6.3667e00], + [-6.9546e00, -7.5424e00, -8.1302e00], + [-8.7180e00, -9.3059e00, -9.8937e00], + ], + [ + [-1.0482e01, -1.1069e01, -1.1657e01], + [-1.2245e01, -1.2833e01, -1.3421e01], + [-1.4009e01, -1.4596e01, -1.5184e01], + ], + ], + [ + [ + [5.9635e01, 6.1199e01, 6.2764e01], + [6.4328e01, 6.5892e01, 6.7456e01], + [6.9021e01, 7.0585e01, 7.2149e01], + ], + [ + [7.3714e01, 7.5278e01, 7.6842e01], + [7.8407e01, 7.9971e01, 8.1535e01], + [8.3099e01, 8.4664e01, 8.6228e01], + ], + [ + [8.7792e01, 8.9357e01, 9.0921e01], + [9.2485e01, 9.4049e01, 9.5614e01], + [9.7178e01, 9.8742e01, 1.0031e02], + ], + ], + [ + [ + [8.1000e01, 8.2000e01, 8.3000e01], + [8.4000e01, 8.5000e01, 8.6000e01], + [8.7000e01, 8.8000e01, 8.9000e01], + ], + [ + [9.0000e01, 9.1000e01, 9.2000e01], + [9.3000e01, 9.4000e01, 9.5000e01], + [9.6000e01, 9.7000e01, 9.8000e01], + ], + [ + [9.9000e01, 1.0000e02, 1.0100e02], + [1.0200e02, 1.0300e02, 1.0400e02], + [1.0500e02, 1.0600e02, 1.0700e02], + ], + ], + ] + ) + ), + ] + ) class TestRandAffineGrid(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) - if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + if "grid" in input_data: + self.assertEqual(type(result), type(expected_val)) + if 
isinstance(result, torch.Tensor): + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index d2f8a60665..a0078aae1e 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -17,179 +17,188 @@ from monai.transforms import RandAffined from monai.utils import GridSampleMode +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(as_tensor_output=False, device=None, spatial_size=None, keys=("img", "seg")), - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=False, device=None, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - dict(as_tensor_output=True, device=None, spatial_size=(2, 2, 2), keys=("img", "seg")), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.ones((1, 2, 2, 2)), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=False, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - cache_grid=True, - keys=("img", "seg"), - mode="bilinear", - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 3))}, - np.array([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=True, - spatial_size=(3, 3), - keys=("img", "seg"), - device=None, - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]), - ], - [ - dict( - prob=0.9, - mode=("bilinear", "nearest"), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - as_tensor_output=True, - spatial_size=(2, 2, 2), - padding_mode="zeros", - device=None, - keys=("img", "seg"), - mode=GridSampleMode.BILINEAR, - ), - {"img": torch.ones((1, 3, 3, 3)), "seg": torch.ones((1, 3, 3, 
3))}, - torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]), - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], - [ - dict( - prob=0.9, - mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - rotate_range=(np.pi / 2,), - shear_range=[1, 2], - translate_range=[2, 1], - scale_range=[0.1, 0.2], - as_tensor_output=False, - spatial_size=(3, 3), - cache_grid=True, - keys=("img", "seg"), - device=torch.device("cpu:0"), - ), - {"img": torch.arange(64).reshape((1, 8, 8)), "seg": torch.arange(64).reshape((1, 8, 8))}, - { - "img": np.array( - [ - [ - [18.736153, 15.581954, 12.4277525], - [27.398798, 24.244598, 21.090399], - [36.061443, 32.90724, 29.753046], - ] - ] - ), - "seg": np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + dict(device=device, spatial_size=None, keys=("img", "seg")), + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=("img", "seg")), + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), + ] + ) + TESTS.append( + [ + dict(device=device, spatial_size=(2, 2, 2), keys=("img", "seg")), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.ones((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode="bilinear", + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + p( + torch.tensor( + [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]] + ) + ), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=("bilinear", "nearest"), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( 
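+                    # bilinear output for "img" blends intensities; "seg" below uses nearest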
+ np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + spatial_size=(2, 2, 2), + padding_mode="zeros", + device=device, + keys=("img", "seg"), + mode=GridSampleMode.BILINEAR, + ), + {"img": p(torch.ones((1, 3, 3, 3))), "seg": p(torch.ones((1, 3, 3, 3)))}, + p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])), + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) + TESTS.append( + [ + dict( + prob=0.9, + mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + rotate_range=(np.pi / 2,), + shear_range=[1, 2], + translate_range=[2, 1], + scale_range=[0.1, 0.2], + spatial_size=(3, 3), + cache_grid=True, + keys=("img", "seg"), + device=device, + ), + {"img": p(torch.arange(64).reshape((1, 8, 8))), "seg": p(torch.arange(64).reshape((1, 8, 8)))}, + { + "img": p( + np.array( + [ + [ + [18.736153, 15.581954, 12.4277525], + [27.398798, 24.244598, 21.090399], + [36.061443, 32.90724, 29.753046], + ] + ] + ) + ), + "seg": p(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])), + }, + ] + ) class TestRandAffined(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_affined(self, input_param, input_data, expected_val): g = RandAffined(**input_param).set_random_state(123) res = g(input_data) @@ -200,23 +209,20 @@ def test_rand_affined(self, input_param, input_data, expected_val): if "_transforms" in key: continue expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) def test_ill_cache(self): with self.assertWarns(UserWarning): # spatial size is None - RandAffined( - as_tensor_output=False, device=None, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg") - ) + RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg")) with self.assertWarns(UserWarning): # spatial size is dynamic RandAffined( - as_tensor_output=False, - device=None, + device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fbfb7d5761..830d0f39e9 100644 --- 
a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -16,90 +16,105 @@ from parameterized import parameterized from monai.transforms import Rand2DElastic +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2)}, - np.ones((3, 2, 2)), - ], - [ - {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "as_tensor_output": False, "device": None}, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.arange(27).reshape((3, 3, 3)), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "padding_mode": "zeros", - }, - {"img": torch.ones((3, 3, 3)), "spatial_size": (2, 2), "mode": "bilinear"}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2)}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 23.9978]], + {"spacing": (0.3, 0.3), "magnitude_range": (1.0, 2.0), "prob": 0.0, "device": device}, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p(np.arange(27).reshape((3, 3, 3))), ] - ), - ], - [ - { - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "padding_mode": "zeros", + }, + {"img": p(torch.ones((3, 3, 3))), "spatial_size": (2, 2), "mode": "bilinear"}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], -] + ) + TESTS.append( + [ + { + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3)))}, + p( 
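+                    # expected values are tied to g.set_random_state(123) in the test below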
+ np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) class TestRand2DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index c63282d571..6add569723 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -16,69 +16,93 @@ from parameterized import parameterized from monai.transforms import Rand3DElastic +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(72).reshape((2, 3, 3, 4))}, - np.arange(72).reshape((2, 3, 3, 4)), - ], - [ - { - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "spatial_size": (2, 2, 2)}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": "cuda" if torch.cuda.is_available() else "cpu", - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "mode": "bilinear"}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(72).reshape((2, 3, 3, 4)))}, + p(np.arange(72).reshape((2, 3, 3, 4))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + }, + {"img": p(torch.ones((2, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "spatial_size": (2, 2, 2)}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 
19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "mode": "bilinear"}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) class TestRand3DElastic(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result, expected_val = result.cpu(), expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py index f8eb026088..5dd5dabda2 100644 --- a/tests/test_rand_elasticd_2d.py +++ b/tests/test_rand_elasticd_2d.py @@ -16,127 +16,147 @@ from parameterized import parameterized from monai.transforms import Rand2DElasticd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.ones((3, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.3, 0.3), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(4).reshape((1, 2, 2)), "seg": torch.arange(4).reshape((1, 2, 2))}, - np.arange(4).reshape((1, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "padding_mode": "zeros", - "device": None, - "spatial_size": (2, 2), - "mode": "bilinear", - }, - {"img": torch.ones((3, 3, 3)), "seg": torch.ones((3, 3, 3))}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( [ - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], - [[0.45531988, 0.0], [0.0, 0.71558857]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p(np.ones((3, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (1.0, 1.0), - "magnitude_range": (1.0, 1.0), - "scale_range": [1.2, 2.2], - "prob": 0.9, - "padding_mode": "border", - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - torch.tensor( + ) + TESTS.append( [ - [[3.0793, 2.6141], [4.0568, 5.9978]], - [[12.0793, 11.6141], [13.0568, 14.9978]], - [[21.0793, 20.6141], [22.0568, 
23.9978]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.3, 0.3), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(4).reshape((1, 2, 2))), "seg": p(torch.arange(4).reshape((1, 2, 2)))}, + p(np.arange(4).reshape((1, 2, 2))), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - np.array( + ) + TESTS.append( [ - [[1.3584113, 1.9251312], [5.626623, 6.642721]], - [[10.358411, 10.925131], [14.626623, 15.642721]], - [[19.358412, 19.92513], [23.626623, 24.642721]], + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (1.0, 2.0), + "prob": 0.9, + "padding_mode": "zeros", + "device": device, + "spatial_size": (2, 2), + "mode": "bilinear", + }, + {"img": p(torch.ones((3, 3, 3))), "seg": p(torch.ones((3, 3, 3)))}, + p( + np.array( + [ + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + [[0.45531988, 0.0], [0.0, 0.71558857]], + ] + ) + ), ] - ), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "spacing": (0.3, 0.3), - "magnitude_range": (0.1, 0.2), - "translate_range": [-0.01, 0.01], - "scale_range": [0.01, 0.02], - "prob": 0.9, - "as_tensor_output": True, - "device": None, - "spatial_size": (2, 2), - }, - {"img": torch.arange(27).reshape((3, 3, 3)), "seg": torch.arange(27).reshape((3, 3, 3))}, - { - "img": torch.tensor( - [ - [[1.3584, 1.9251], [5.6266, 6.6427]], - [[10.3584, 10.9251], [14.6266, 15.6427]], - [[19.3584, 19.9251], [23.6266, 24.6427]], - ] - ), - "seg": torch.tensor([[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]), - }, - ], -] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (1.0, 1.0), + "magnitude_range": (1.0, 1.0), + "scale_range": [1.2, 2.2], + "prob": 0.9, + "padding_mode": "border", + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + torch.tensor( + [ + [[3.0793, 2.6141], [4.0568, 5.9978]], + [[12.0793, 11.6141], [13.0568, 14.9978]], + [[21.0793, 20.6141], [22.0568, 23.9978]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + p( + np.array( + [ + [[1.3584113, 1.9251312], [5.626623, 6.642721]], + [[10.358411, 10.925131], [14.626623, 15.642721]], + [[19.358412, 19.92513], [23.626623, 24.642721]], + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "spacing": (0.3, 0.3), + "magnitude_range": (0.1, 0.2), + "translate_range": [-0.01, 0.01], + "scale_range": [0.01, 0.02], + "prob": 0.9, + "device": device, + "spatial_size": (2, 2), + }, + {"img": p(torch.arange(27).reshape((3, 3, 3))), "seg": p(torch.arange(27).reshape((3, 3, 3)))}, + { + "img": p( + torch.tensor( + [ + [[1.3584, 1.9251], [5.6266, 6.6427]], + [[10.3584, 10.9251], [14.6266, 15.6427]], + [[19.3584, 19.9251], [23.6266, 24.6427]], + ] + ) + ), + "seg": 
p( + torch.tensor( + [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]] + ) + ), + }, + ] + ) class TestRand2DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_2d_elasticd(self, input_param, input_data, expected_val): g = Rand2DElasticd(**input_param) g.set_random_state(123) @@ -144,11 +164,11 @@ def test_rand_2d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py index 47ab814882..dfa960d374 100644 --- a/tests/test_rand_elasticd_3d.py +++ b/tests/test_rand_elasticd_3d.py @@ -16,98 +16,128 @@ from parameterized import parameterized from monai.transforms import Rand3DElasticd +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, -1, -1), - }, - {"img": torch.ones((2, 3, 3, 3)), "seg": torch.ones((2, 3, 3, 3))}, - np.ones((2, 2, 3, 3)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 2.3), - "sigma_range": (1.0, 20.0), - "prob": 0.0, - "as_tensor_output": False, - "device": None, - "spatial_size": -1, - }, - {"img": torch.arange(8).reshape((1, 2, 2, 2)), "seg": torch.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[6.4939356, 7.50289], [9.518351, 10.522849]], [[15.512375, 16.523542], [18.531467, 19.53646]]]]), - ], - [ - { - "keys": ("img", "seg"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": False, - "device": None, - "spatial_size": (2, 2, 2), - "mode": "bilinear", - }, - {"img": torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - np.array([[[[5.0069294, 9.463932], [9.287769, 13.739735]], [[12.319424, 16.777205], [16.594296, 21.045748]]]]), - ], - [ - { - "keys": ("img", "seg"), - "mode": ("bilinear", "nearest"), - "magnitude_range": (0.3, 0.3), - "sigma_range": (1.0, 2.0), - "prob": 0.9, - "rotate_range": [1, 1, 1], - "as_tensor_output": True, - "device": torch.device("cpu:0"), - "spatial_size": (2, 2, 2), - }, - {"img": 
torch.arange(27).reshape((1, 3, 3, 3)), "seg": torch.arange(27).reshape((1, 3, 3, 3))}, - { - "img": torch.tensor([[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]), - "seg": torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]]), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": (2, -1, -1), + }, + {"img": p(torch.ones((2, 3, 3, 3))), "seg": p(torch.ones((2, 3, 3, 3)))}, + p(np.ones((2, 2, 3, 3))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 2.3), + "sigma_range": (1.0, 20.0), + "prob": 0.0, + "device": device, + "spatial_size": -1, + }, + {"img": p(torch.arange(8).reshape((1, 2, 2, 2))), "seg": p(torch.arange(8).reshape((1, 2, 2, 2)))}, + p(np.arange(8).reshape((1, 2, 2, 2))), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[6.4939356, 7.50289], [9.518351, 10.522849]], + [[15.512375, 16.523542], [18.531467, 19.53646]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + "mode": "bilinear", + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + p( + np.array( + [ + [ + [[5.0069294, 9.463932], [9.287769, 13.739735]], + [[12.319424, 16.777205], [16.594296, 21.045748]], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + { + "keys": ("img", "seg"), + "mode": ("bilinear", "nearest"), + "magnitude_range": (0.3, 0.3), + "sigma_range": (1.0, 2.0), + "prob": 0.9, + "rotate_range": [1, 1, 1], + "device": device, + "spatial_size": (2, 2, 2), + }, + {"img": p(torch.arange(27).reshape((1, 3, 3, 3))), "seg": p(torch.arange(27).reshape((1, 3, 3, 3)))}, + { + "img": p( + torch.tensor( + [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]] + ) + ), + "seg": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])), + }, + ] + ) class TestRand3DElasticd(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_3d_elasticd(self, input_param, input_data, expected_val): g = Rand3DElasticd(**input_param) g.set_random_state(123) @@ -115,11 +145,11 @@ def test_rand_3d_elasticd(self, input_param, input_data, expected_val): for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected, torch.Tensor)) + self.assertEqual(type(result), type(expected)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected.cpu().numpy(), rtol=1e-4, 
atol=1e-4) - else: - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected.device) + result, expected = result.cpu(), expected.cpu() + np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) if __name__ == "__main__": diff --git a/tests/test_resampler.py b/tests/test_resampler.py index 2be94acebd..7eba61def6 100644 --- a/tests/test_resampler.py +++ b/tests/test_resampler.py @@ -17,69 +17,128 @@ from monai.transforms import Resample from monai.transforms.utils import create_grid +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((2, 2)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2))}, - np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]), - ], - [ - dict(padding_mode="reflection", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4)), "img": np.arange(4).reshape((1, 2, 2)), "mode": "nearest"}, - np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]), - ], - [ - dict(padding_mode="zeros", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( [ - [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 4.0, 5.0, 0.0], [0.0, 6.0, 7.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - ] + dict(padding_mode="zeros", device=None), + {"grid": p(create_grid((2, 2))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 1.0], [2.0, 3.0]]])), ] - ), - ], - [ - dict(padding_mode="border", as_tensor_output=False, device=None), - {"grid": create_grid((4, 4, 4)), "img": np.arange(8).reshape((1, 2, 2, 2)), "mode": "bilinear"}, - np.array( + ) + TESTS.append( [ - [ - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0], [2.0, 2.0, 3.0, 3.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - [[4.0, 4.0, 5.0, 5.0], [4.0, 4.0, 5.0, 5.0], [6.0, 6.0, 7.0, 7.0], [6.0, 6.0, 7.0, 7.0]], - ] + dict(padding_mode="zeros", device=None), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), ] - ), - ], -] + ) + TESTS.append( + [ + dict(padding_mode="border", device=None), + {"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2)))}, + q(np.array([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="reflection", device=None), + 
{"grid": p(create_grid((4, 4))), "img": q(np.arange(4).reshape((1, 2, 2))), "mode": "nearest"}, + q(np.array([[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]])), + ] + ) + TESTS.append( + [ + dict(padding_mode="zeros", device=None), + {"grid": p(create_grid((4, 4, 4))), "img": q(np.arange(8).reshape((1, 2, 2, 2))), "mode": "bilinear"}, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 2.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 4.0, 5.0, 0.0], + [0.0, 6.0, 7.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ] + ] + ) + ), + ] + ) + TESTS.append( + [ + dict(padding_mode="border", device=None), + {"grid": p(create_grid((4, 4, 4))), "img": q(np.arange(8).reshape((1, 2, 2, 2))), "mode": "bilinear"}, + q( + np.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [2.0, 2.0, 3.0, 3.0], + [2.0, 2.0, 3.0, 3.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + [ + [4.0, 4.0, 5.0, 5.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + [6.0, 6.0, 7.0, 7.0], + ], + ] + ] + ) + ), + ] + ) class TestResample(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_resample(self, input_param, input_data, expected_val): g = Resample(**input_param) result = g(**input_data) - self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) + self.assertEqual(type(result), type(expected_val)) if isinstance(result, torch.Tensor): - np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) - else: - np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + self.assertEqual(result.device, expected_val.device) + result = result.cpu() + expected_val = expected_val.cpu() + np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) if __name__ == "__main__": From d71a91d9f0ed66c4b3eafcde10e70c32a2ba6272 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 10:21:48 +0100 Subject: [PATCH 112/176] numpy/pytorch compatibility Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a39b0526bb..be5629839a 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -759,7 +759,7 @@ def __call__(self, data: DataObjects.Mapping) -> Dict[Hashable, DataObjects.Imag if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid).clone() # type: ignore + grid = deepcopy(self.rand_affine.rand_affine_grid(grid=grid)) affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() # type: ignore[assignment] for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): From 9816ffe4b5f0216f830209194fdc44d4992f0509 Mon Sep 17 00:00:00 2001 From: Richard 
Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 11:17:46 +0100 Subject: [PATCH 113/176] resize, resized, png_writer Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/png_writer.py | 44 ++++++++++++++++++------------- monai/transforms/spatial/array.py | 25 ++++++++++-------- tests/test_png_rw.py | 37 +++++++++++++++++--------- tests/test_resize.py | 21 ++++++++++----- tests/test_resized.py | 19 ++++++++++--- 5 files changed, 94 insertions(+), 52 deletions(-) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 9ce01ed97f..3ab921bf73 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -12,15 +12,18 @@ from typing import Optional, Sequence, Union import numpy as np +import torch from monai.transforms.spatial.array import Resize from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import +from monai.utils.enums import DataObjects +from monai.utils.misc import convert_data_type Image, _ = optional_import("PIL", name="Image") def write_png( - data: np.ndarray, + data: DataObjects.Images, file_name: str, output_spatial_shape: Optional[Sequence[int]] = None, mode: Union[InterpolateMode, str] = InterpolateMode.BICUBIC, @@ -47,39 +50,42 @@ def write_png( ValueError: When ``scale`` is not one of [255, 65535]. """ - if not isinstance(data, np.ndarray): - raise AssertionError("input data must be numpy array.") - if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel - data = data.squeeze(2) + if not isinstance(data, (np.ndarray, torch.Tensor)): + raise AssertionError("input data must be np.ndarray/torch.Tensor.") + data_np: np.ndarray + data_np, *_ = convert_data_type(data, np.ndarray) # type: ignore + if len(data_np.shape) == 3 and data_np.shape[2] == 1: # PIL Image can't save image with 1 channel + data_np = data_np.squeeze(2) if output_spatial_shape is not None: output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) mode = InterpolateMode(mode) align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) - _min, _max = np.min(data), np.max(data) - if len(data.shape) == 3: - data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) - data = np.moveaxis(data, 0, -1) + _min, _max = np.min(data_np), np.max(data_np) + if len(data_np.shape) == 3: + data_np = np.moveaxis(data_np, -1, 0) # to channel first + data_np = xform(data_np) # type: ignore + data_np = np.moveaxis(data_np, 0, -1) else: # (H, W) - data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # first channel + data_np = np.expand_dims(data_np, 0) # make a channel + # first channel + data_np = xform(data_np)[0] # type: ignore if mode != InterpolateMode.NEAREST: - data = np.clip(data, _min, _max) # type: ignore + data_np = np.clip(data_np, _min, _max) # type: ignore if scale is not None: - data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] + data_np = np.clip(data_np, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8) + data_np = (scale * data_np).astype(np.uint8) elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16) + data_np = (scale * data_np).astype(np.uint16) else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") # PNG data 
must be int number - if data.dtype not in (np.uint8, np.uint16): # type: ignore - data = data.astype(np.uint8) + if data_np.dtype not in (np.uint8, np.uint16): # type: ignore + data_np = data_np.astype(np.uint8) - img = Image.fromarray(data) + img = Image.fromarray(data_np) img.save(file_name, "PNG") return diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 273281ce1b..1807cfa414 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -347,7 +347,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return self.post_convert_data(result, orig_type, orig_device) -class Resize(ToDoTransform): +class Resize(TorchTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -377,10 +377,10 @@ def __init__( def __call__( self, - img: np.ndarray, + img: DataObjects.Images, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, - ) -> np.ndarray: + ) -> DataObjects.Images: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). @@ -395,25 +395,28 @@ def __call__( ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. """ - input_ndim = img.ndim - 1 # spatial ndim + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float) # type: ignore + input_ndim = img_t.ndim - 1 # spatial ndim output_ndim = len(self.spatial_size) if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) - img = img.reshape(input_shape) + input_shape = ensure_tuple_size(img_t.shape, output_ndim + 1, 1) + img_t = img_t.reshape(input_shape) elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." 
)
 
-        spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:])
-        resized = torch.nn.functional.interpolate(  # type: ignore
-            input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0),
+        spatial_size = fall_back_tuple(self.spatial_size, img_t.shape[1:])
+        resized = torch.nn.functional.interpolate(
+            input=img_t.unsqueeze(0),
             size=spatial_size,
             mode=self.mode.value if mode is None else InterpolateMode(mode).value,
             align_corners=self.align_corners if align_corners is None else align_corners,
         )
-        resized = resized.squeeze(0).detach().cpu().numpy()
-        return np.asarray(resized)
+        resized = resized.squeeze(0)
+        out, *_ = convert_data_type(resized, orig_type, orig_device)
+        return out
 
 
 class Rotate(TorchTransform, ThreadUnsafe):
diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py
index 815d0bcf2c..b4dc5f4e49 100644
--- a/tests/test_png_rw.py
+++ b/tests/test_png_rw.py
@@ -14,62 +14,74 @@
 import unittest
 
 import numpy as np
+from parameterized import parameterized
 from PIL import Image
 
 from monai.data import write_png
+from tests.utils import TEST_NDARRAYS
+
+TESTS = []
+for p in TEST_NDARRAYS:
+    TESTS.append((p,))
 
 
 class TestPngWrite(unittest.TestCase):
-    def test_write_gray(self):
+    @parameterized.expand(TESTS)
+    def test_write_gray(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(2, 3)
             img_save_val = (255 * img).astype(np.uint8)
-            write_png(img, image_name, scale=255)
+            write_png(in_type(img), image_name, scale=255)
             out = np.asarray(Image.open(image_name))
             np.testing.assert_allclose(out, img_save_val)
 
-    def test_write_gray_1height(self):
+    @parameterized.expand(TESTS)
+    def test_write_gray_1height(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(1, 3)
             img_save_val = (65535 * img).astype(np.uint16)
-            write_png(img, image_name, scale=65535)
+            write_png(in_type(img), image_name, scale=65535)
             out = np.asarray(Image.open(image_name))
             np.testing.assert_allclose(out, img_save_val)
 
-    def test_write_gray_1channel(self):
+    @parameterized.expand(TESTS)
+    def test_write_gray_1channel(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(2, 3, 1)
             img_save_val = (255 * img).astype(np.uint8).squeeze(2)
-            write_png(img, image_name, scale=255)
+            write_png(in_type(img), image_name, scale=255)
             out = np.asarray(Image.open(image_name))
             np.testing.assert_allclose(out, img_save_val)
 
-    def test_write_rgb(self):
+    @parameterized.expand(TESTS)
+    def test_write_rgb(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(2, 3, 3)
             img_save_val = (255 * img).astype(np.uint8)
-            write_png(img, image_name, scale=255)
+            write_png(in_type(img), image_name, scale=255)
             out = np.asarray(Image.open(image_name))
             np.testing.assert_allclose(out, img_save_val)
 
-    def test_write_2channels(self):
+    @parameterized.expand(TESTS)
+    def test_write_2channels(self, in_type):
         with tempfile.TemporaryDirectory() as out_dir:
             image_name = os.path.join(out_dir, "test.png")
             img = np.random.rand(2, 3, 2)
             img_save_val = (255 * img).astype(np.uint8)
-            write_png(img, image_name, scale=255)
+            write_png(in_type(img), image_name, scale=255)
             out = np.asarray(Image.open(image_name))
             np.testing.assert_allclose(out, img_save_val)
 
-    def test_write_output_shape(self):
+    @parameterized.expand(TESTS)
+    
def test_write_output_shape(self, in_type): with tempfile.TemporaryDirectory() as out_dir: image_name = os.path.join(out_dir, "test.png") img = np.random.rand(2, 2, 3) - write_png(img, image_name, (4, 4), scale=255) + write_png(in_type(img), image_name, (4, 4), scale=255) out = np.asarray(Image.open(image_name)) np.testing.assert_allclose(out.shape, (4, 4, 3)) diff --git a/tests/test_resize.py b/tests/test_resize.py index 22a68bcf85..56353f2bea 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -10,13 +10,22 @@ # limitations under the License. import unittest +from typing import List, Tuple import numpy as np import skimage.transform +import torch from parameterized import parameterized from monai.transforms import Resize -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D + +TESTS: List[Tuple] = [] +for p in TEST_NDARRAYS: + TESTS.append((p, (32, -1), "area")) + TESTS.append((p, (32, 32), "area")) + TESTS.append((p, (32, 32, 32), "trilinear")) + TESTS.append((p, (256, 256), "bilinear")) class TestResize(NumpyImageTestCase2D): @@ -29,10 +38,8 @@ def test_invalid_inputs(self): resize = Resize(spatial_size=(128,), mode="order") resize(self.imt[0]) - @parameterized.expand( - [((32, -1), "area"), ((32, 32), "area"), ((32, 32, 32), "trilinear"), ((256, 256), "bilinear")] - ) - def test_correct_results(self, spatial_size, mode): + @parameterized.expand(TESTS) + def test_correct_results(self, in_type, spatial_size, mode): resize = Resize(spatial_size, mode=mode) _order = 0 if mode.endswith("linear"): @@ -47,7 +54,9 @@ def test_correct_results(self, spatial_size, mode): ) ) expected = np.stack(expected).astype(np.float32) - out = resize(self.imt[0]) + out = resize(in_type(self.imt[0])) + if isinstance(out, torch.Tensor): + out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) diff --git a/tests/test_resized.py b/tests/test_resized.py index d89c866af3..29fab0b5c9 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -10,13 +10,22 @@ # limitations under the License. 
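# NOTE: TEST_NDARRAYS (imported below from tests.utils) is assumed here to be a
# tuple of array-constructor callables (numpy and torch variants), so each
# parameterized case below runs once per array backend.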
import unittest +from typing import List, Tuple import numpy as np import skimage.transform +import torch from parameterized import parameterized from monai.transforms import Resized -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D + +TESTS: List[Tuple] = [] +for p in TEST_NDARRAYS: + TESTS.append((p, (32, -1), "area")) + TESTS.append((p, (64, 64), "area")) + TESTS.append((p, (32, 32, 32), "area")) + TESTS.append((p, (256, 256), "bilinear")) class TestResized(NumpyImageTestCase2D): @@ -29,8 +38,8 @@ def test_invalid_inputs(self): resize = Resized(keys="img", spatial_size=(128,), mode="order") resize({"img": self.imt[0]}) - @parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")]) - def test_correct_results(self, spatial_size, mode): + @parameterized.expand(TESTS) + def test_correct_results(self, in_type, spatial_size, mode): resize = Resized("img", spatial_size, mode) _order = 0 if mode.endswith("linear"): @@ -45,7 +54,9 @@ def test_correct_results(self, spatial_size, mode): ) ) expected = np.stack(expected).astype(np.float32) - out = resize({"img": self.imt[0]})["img"] + out = resize({"img": in_type(self.imt[0])})["img"] + if isinstance(out, torch.Tensor): + out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) From 8d5446cd80cc7985799d4b949a2d94054579c91e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:02:35 +0100 Subject: [PATCH 114/176] orientation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 22 ++- monai/utils/misc.py | 2 + tests/test_orientation.py | 238 ++++++++++++++++++------------ tests/test_orientationd.py | 61 +++++--- 4 files changed, 198 insertions(+), 125 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1807cfa414..2d51353afe 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -25,6 +25,7 @@ from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import ( + NumpyTransform, Randomizable, RandomizableTransform, ThreadUnsafe, @@ -231,7 +232,7 @@ def __call__( return output_data, affine, new_affine -class Orientation(ToDoTransform): +class Orientation(NumpyTransform): """ Change the input image's orientation into the specified based on `axcodes`. """ @@ -269,8 +270,8 @@ def __init__( self.labels = labels def __call__( - self, data_array: np.ndarray, affine: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self, data_array: DataObjects.Images, affine: Optional[DataObjects.Images] = None + ) -> Tuple[DataObjects.Images, DataObjects.Images, np.ndarray]: """ original orientation of `data_array` is defined by `affine`. @@ -286,13 +287,16 @@ def __call__( data_array (reoriented in `self.axcodes`), original axcodes, current axcodes. 
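
        Example (a minimal sketch; the array and affine values are arbitrary
        and only illustrate the return signature):

            ornt = Orientation(axcodes="RAS")
            data, affine, new_affine = ornt(np.ones((1, 2, 3, 4)), np.diag([-1, -1, 1, 1]))
            # `data` is flipped into RAS order; `new_affine` records the reorientation
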
""" - sr = data_array.ndim - 1 + data_np: np.ndarray + data_np, orig_type, orig_device = convert_data_type(data_array, np.ndarray) # type: ignore + sr = data_np.ndim - 1 if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") if affine is None: affine = np.eye(sr + 1, dtype=np.float64) affine_ = np.eye(sr + 1, dtype=np.float64) else: + affine, *_ = convert_data_type(affine, np.ndarray) affine_ = to_affine_nd(sr, affine) # type: ignore src = nib.io_orientation(affine_) if self.as_closest_canonical: @@ -309,12 +313,14 @@ def __call__( ornt = spatial_ornt.copy() ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) - shape = data_array.shape[1:] - data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) + shape = data_np.shape[1:] + data_np = np.ascontiguousarray(nib.orientations.apply_orientation(data_np, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) - return data_array, affine, new_affine + data_out, *_ = convert_data_type(data_np, orig_type, orig_device) + + return data_out, affine, new_affine class Flip(TorchTransform): @@ -1045,7 +1051,7 @@ def __call__( return grid, affine -class RandAffineGrid(Randomizable, ToDoTransform): +class RandAffineGrid(Randomizable, TorchTransform): """ Generate randomised affine grid. """ diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 2a5a8bc0b3..3d8ba2a4e3 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -385,6 +385,8 @@ def get_dtype(data: Any): if output_type is torch.Tensor: if orig_type is np.ndarray: + if (np.array(data.strides) < 0).any(): # copy if -ve stride + data = data.copy() data = torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data)) else: data = torch.as_tensor(data) diff --git a/tests/test_orientation.py b/tests/test_orientation.py index aa7f33a469..875ad078bc 100644 --- a/tests/test_orientation.py +++ b/tests/test_orientation.py @@ -13,112 +13,164 @@ import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.transforms import Orientation, create_rotate, create_translate +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.eye(4)}, - np.arange(12).reshape((2, 1, 2, 3)), - "RAS", - ], - [ - {"axcodes": "ALS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), - "ALS", - ], - [ - {"axcodes": "RAS"}, - np.arange(12).reshape((2, 1, 2, 3)), - {"affine": np.diag([-1, -1, 1, 1])}, - np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), - "RAS", - ], - [ - {"axcodes": "AL"}, - np.arange(6).reshape((2, 1, 3)), - {"affine": np.eye(3)}, - np.array([[[0], [1], [2]], [[3], [4], [5]]]), - "AL", - ], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.eye(2)}, np.array([[2, 1, 0], [5, 4, 3]]), "L"], - [{"axcodes": "L"}, np.arange(6).reshape((2, 3)), {"affine": np.diag([-1, 1])}, np.arange(6).reshape((2, 3)), "L"], - [ - {"axcodes": "LPS"}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 
9]]]]), - "LPS", - ], - [ - {"as_closest_canonical": True}, - np.arange(12).reshape((2, 1, 2, 3)), - { - "affine": create_translate(3, (10, 20, 30)) - @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) - @ np.diag([-1, 1, 1, 1]) - }, - np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]), - "RAS", - ], - [ - {"as_closest_canonical": True}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[3, 0], [4, 1], [5, 2]]]), - "RA", - ], - [ - {"axcodes": "LP"}, - np.arange(6).reshape((1, 2, 3)), - {"affine": create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])}, - np.array([[[2, 5], [1, 4], [0, 3]]]), - "LP", - ], - [ - {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "LPID", - ], - [ - {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))}, - np.zeros((1, 2, 3, 4, 5)), - {"affine": np.diag([-1, -0.2, -1, 1, 1])}, - np.zeros((1, 2, 3, 4, 5)), - "RASD", - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + {"axcodes": "RAS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.eye(4))}, + np.arange(12).reshape((2, 1, 2, 3)), + "RAS", + ] + ) + TESTS.append( + [ + {"axcodes": "ALS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.diag([-1, -1, 1, 1]))}, + np.array([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]), + "ALS", + ] + ) + TESTS.append( + [ + {"axcodes": "RAS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + {"affine": q(np.diag([-1, -1, 1, 1]))}, + np.array([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]), + "RAS", + ] + ) + TESTS.append( + [ + {"axcodes": "AL"}, + p(np.arange(6).reshape((2, 1, 3))), + {"affine": q(np.eye(3))}, + np.array([[[0], [1], [2]], [[3], [4], [5]]]), + "AL", + ] + ) + TESTS.append( + [ + {"axcodes": "L"}, + p(np.arange(6).reshape((2, 3))), + {"affine": q(np.eye(2))}, + np.array([[2, 1, 0], [5, 4, 3]]), + "L", + ] + ) + TESTS.append( + [ + {"axcodes": "L"}, + p(np.arange(6).reshape((2, 3))), + {"affine": q(np.eye(2))}, + np.array([[2, 1, 0], [5, 4, 3]]), + "L", + ] + ) + TESTS.append( + [ + {"axcodes": "L"}, + p(np.arange(6).reshape((2, 3))), + {"affine": q(np.diag([-1, 1]))}, + np.arange(6).reshape((2, 3)), + "L", + ] + ) + TESTS.append( + [ + {"axcodes": "LPS"}, + p(np.arange(12).reshape((2, 1, 2, 3))), + { + "affine": q( + create_translate(3, (10, 20, 30)) + @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) + @ np.diag([-1, 1, 1, 1]) + ) + }, + np.array([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]), + "LPS", + ] + ) + TESTS.append( + [ + {"as_closest_canonical": True}, + np.arange(12).reshape((2, 1, 2, 3)), + { + "affine": q( + create_translate(3, (10, 20, 30)) + @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4)) + @ np.diag([-1, 1, 1, 1]) + ) + }, + np.array([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]), + "RAS", + ] + ) + TESTS.append( + [ + {"as_closest_canonical": True}, + p(np.arange(6).reshape((1, 2, 3))), + {"affine": q(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1]))}, + np.array([[[3, 0], [4, 1], [5, 2]]]), + "RA", + ] + ) + TESTS.append( + [ + {"axcodes": "LP"}, + p(np.arange(6).reshape((1, 2, 3))), + {"affine": q(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1]))}, + 
np.array([[[2, 5], [1, 4], [0, 3]]]), + "LP", + ] + ) + TESTS.append( + [ + {"axcodes": "LPID", "labels": tuple(zip("LPIC", "RASD"))}, + p(np.zeros((1, 2, 3, 4, 5))), + {"affine": q(np.diag([-1, -0.2, -1, 1, 1]))}, + np.zeros((1, 2, 3, 4, 5)), + "LPID", + ] + ) + TESTS.append( + [ + {"as_closest_canonical": True, "labels": tuple(zip("LPIC", "RASD"))}, + p(np.zeros((1, 2, 3, 4, 5))), + {"affine": q(np.diag([-1, -0.2, -1, 1, 1]))}, + np.zeros((1, 2, 3, 4, 5)), + "RASD", + ] + ) -ILL_CASES = [ - # no axcodes or as_cloest_canonical - [{}, np.arange(6).reshape((2, 3)), "L"], - # too short axcodes - [{"axcodes": "RA"}, np.arange(12).reshape((2, 1, 2, 3)), {"affine": np.eye(4)}], -] +ILL_CASES = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + # no axcodes or as_cloest_canonical + ILL_CASES.append([{}, p(np.arange(6).reshape((2, 3))), "L"]) + # too short axcodes + ILL_CASES.append([{"axcodes": "RA"}, p(np.arange(12).reshape((2, 1, 2, 3))), {"affine": q(np.eye(4))}]) class TestOrientationCase(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_ornt(self, init_param, img, data_param, expected_data, expected_code): ornt = Orientation(**init_param) res = ornt(img, **data_param) - if not isinstance(res, tuple): - np.testing.assert_allclose(res, expected_data) - return + res = [i.cpu() if isinstance(i, torch.Tensor) else i for i in res] np.testing.assert_allclose(res[0], expected_data) original_affine = data_param["affine"] + if isinstance(original_affine, torch.Tensor): + original_affine = original_affine.cpu() np.testing.assert_allclose(original_affine, res[1]) new_code = nib.orientations.aff2axcodes(res[2], labels=ornt.labels) self.assertEqual("".join(new_code), expected_code) diff --git a/tests/test_orientationd.py b/tests/test_orientationd.py index 452172ce9b..01e647decb 100644 --- a/tests/test_orientationd.py +++ b/tests/test_orientationd.py @@ -13,25 +13,34 @@ import nibabel as nib import numpy as np +from parameterized import parameterized from monai.transforms import Orientationd +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append((p, q)) class TestOrientationdCase(unittest.TestCase): - def test_orntd(self): - data = {"seg": np.ones((2, 1, 2, 3)), "seg_meta_dict": {"affine": np.eye(4)}} + @parameterized.expand(TESTS) + def test_orntd(self, im_type, affine_type): + data = {"seg": im_type(np.ones((2, 1, 2, 3))), "seg_meta_dict": {"affine": affine_type(np.eye(4))}} ornt = Orientationd(keys="seg", axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) code = nib.aff2axcodes(res["seg_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) - def test_orntd_3d(self): + @parameterized.expand(TESTS) + def test_orntd_3d(self, im_type, affine_type): data = { - "seg": np.ones((2, 1, 2, 3)), - "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 2, 3))), + "img": im_type(np.ones((2, 1, 2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") res = ornt(data) @@ -42,12 +51,13 @@ def test_orntd_3d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "I")) - def test_orntd_2d(self): + @parameterized.expand(TESTS) + def test_orntd_2d(self, 
im_type, affine_type): data = { - "seg": np.ones((2, 1, 3)), - "img": np.ones((2, 1, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 3))), + "img": im_type(np.ones((2, 1, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="PLI") res = ornt(data) @@ -57,12 +67,13 @@ def test_orntd_2d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("P", "L", "S")) - def test_orntd_1d(self): + @parameterized.expand(TESTS) + def test_orntd_1d(self, im_type, affine_type): data = { - "seg": np.ones((2, 3)), - "img": np.ones((2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 3))), + "img": im_type(np.ones((2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), axcodes="L") res = ornt(data) @@ -72,12 +83,13 @@ def test_orntd_1d(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("L", "A", "S")) - def test_orntd_canonical(self): + @parameterized.expand(TESTS) + def test_orntd_canonical(self, im_type, affine_type): data = { - "seg": np.ones((2, 1, 2, 3)), - "img": np.ones((2, 1, 2, 3)), - "seg_meta_dict": {"affine": np.eye(4)}, - "img_meta_dict": {"affine": np.eye(4)}, + "seg": im_type(np.ones((2, 1, 2, 3))), + "img": im_type(np.ones((2, 1, 2, 3))), + "seg_meta_dict": {"affine": affine_type(np.eye(4))}, + "img_meta_dict": {"affine": affine_type(np.eye(4))}, } ornt = Orientationd(keys=("img", "seg"), as_closest_canonical=True) res = ornt(data) @@ -88,8 +100,9 @@ def test_orntd_canonical(self): code = nib.aff2axcodes(res["img_meta_dict"]["affine"], ornt.ornt_transform.labels) self.assertEqual(code, ("R", "A", "S")) - def test_orntd_no_metadata(self): - data = {"seg": np.ones((2, 1, 2, 3))} + @parameterized.expand(TESTS) + def test_orntd_no_metadata(self, im_type, _): + data = {"seg": im_type(np.ones((2, 1, 2, 3)))} ornt = Orientationd(keys="seg", axcodes="RAS") res = ornt(data) np.testing.assert_allclose(res["seg"].shape, (2, 1, 2, 3)) From a3e235da999ac7c8705b4a37420c0369e6d61bf1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:03:29 +0100 Subject: [PATCH 115/176] TTA Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_testtimeaugmentation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index ab9c7c4c18..a7cbd49656 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -82,7 +82,6 @@ def test_test_time_augmentation(self): scale_range=((0.8, 1), (0.8, 1)), padding_mode="zeros", mode=("bilinear", "nearest"), - as_tensor_output=False, ), CropForegroundd(keys, source_key="image"), DivisiblePadd(keys, 4), From 5f0d6a59cc8246fc140e8df9b11b7f4f887bce86 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:35:25 +0100 Subject: [PATCH 116/176] add_coordinate_channels Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 16 +++++--- tests/test_add_coordinate_channels.py | 26 +++++++----- 
tests/test_add_coordinate_channelsd.py | 56 +++++++++++++++++--------- tests/test_resize.py | 5 ++- 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2d51353afe..2a835d43f9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -29,7 +29,6 @@ Randomizable, RandomizableTransform, ThreadUnsafe, - ToDoTransform, TorchOrNumpyTransform, TorchTransform, Transform, @@ -1763,7 +1762,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class AddCoordinateChannels(ToDoTransform): +class AddCoordinateChannels(TorchOrNumpyTransform): """ Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling, to allow feeding of the patch's location into the network. @@ -1785,7 +1784,7 @@ def __init__( """ self.spatial_channels = spatial_channels - def __call__(self, img: DataObjects.Images): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: data to be transformed, assuming `img` is channel first. @@ -1799,8 +1798,15 @@ def __call__(self, img: DataObjects.Images): raise ValueError("cannot add AddCoordinateChannels channel for dimension 0, as 0 is channel dim.") spatial_dims = img.shape[1:] - coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij")) + coord_channels = np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij") + coord_channels, *_ = convert_data_type( + coord_channels, type(img), device=img.device if isinstance(img, torch.Tensor) else None + ) # only keep required dimensions. need to subtract 1 since im will be 0-based # but user input is 1-based (because channel dim is 0) coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]] - return np.concatenate((img, coord_channels), axis=0) + if isinstance(img, torch.Tensor): + out = torch.cat((img, coord_channels), dim=0) + else: + out = np.concatenate((img, coord_channels), axis=0) + return out diff --git a/tests/test_add_coordinate_channels.py b/tests/test_add_coordinate_channels.py index 3399008e02..e90dca4688 100644 --- a/tests/test_add_coordinate_channels.py +++ b/tests/test_add_coordinate_channels.py @@ -12,32 +12,38 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannels +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"spatial_channels": (1, 2, 3)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (4, 3, 3, 3)] - -TEST_CASE_2 = [{"spatial_channels": (1,)}, np.random.randint(0, 2, size=(1, 3, 3, 3)), (2, 3, 3, 3)] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,)}, np.random.randint(0, 2, size=(1, 3, 3))] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2)}, np.random.randint(0, 2, size=(1, 3, 3))] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"spatial_channels": (1, 2, 3)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)]) + TESTS.append([{"spatial_channels": (1,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)]) + TEST_CASES_ERROR_1.append([{"spatial_channels": (3,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) + TEST_CASES_ERROR_2.append([{"spatial_channels": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3)))]) class TestAddCoordinateChannels(unittest.TestCase): - 
@parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): result = AddCoordinateChannels(**input_param)(input) + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + input = input.cpu() + result = result.cpu() self.assertEqual(list(result.shape), list(expected_shape)) np.testing.assert_array_equal(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) - @parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannels(**input_param)(input) diff --git a/tests/test_add_coordinate_channelsd.py b/tests/test_add_coordinate_channelsd.py index 0fa6aae1c9..cb504983e6 100644 --- a/tests/test_add_coordinate_channelsd.py +++ b/tests/test_add_coordinate_channelsd.py @@ -12,40 +12,56 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddCoordinateChannelsd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"spatial_channels": (1, 2, 3), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (4, 3, 3, 3), -] +TESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"spatial_channels": (1, 2, 3), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (4, 3, 3, 3), + ] + ) + TESTS.append( + [ + {"spatial_channels": (1,), "keys": ["img"]}, + {"img": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, + (2, 3, 3, 3), + ] + ) -TEST_CASE_2 = [ - {"spatial_channels": (1,), "keys": ["img"]}, - {"img": np.random.randint(0, 2, size=(1, 3, 3, 3))}, - (2, 3, 3, 3), -] - -TEST_CASE_ERROR_3 = [{"spatial_channels": (3,), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] - -TEST_CASE_ERROR_4 = [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": np.random.randint(0, 2, size=(1, 3, 3))}] + TEST_CASES_ERROR_1.append( + [{"spatial_channels": (3,), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) + TEST_CASES_ERROR_2.append( + [{"spatial_channels": (0, 1, 2), "keys": ["img"]}, {"img": p(np.random.randint(0, 2, size=(1, 3, 3)))}] + ) class TestAddCoordinateChannels(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input, expected_shape): - result = AddCoordinateChannelsd(**input_param)(input) - self.assertEqual(list(result["img"].shape), list(expected_shape)) - np.testing.assert_array_equal(input["img"][0, ...], result["img"][0, ...]) + result = AddCoordinateChannelsd(**input_param)(input)["img"] + input = input["img"] + self.assertEqual(type(result), type(input)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input.device) + input = input.cpu() + result = result.cpu() + self.assertEqual(result.shape, expected_shape) + np.testing.assert_array_equal(input[0, ...], result[0, ...]) - @parameterized.expand([TEST_CASE_ERROR_3]) + @parameterized.expand(TEST_CASES_ERROR_1) def test_max_channel(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) - 
@parameterized.expand([TEST_CASE_ERROR_4]) + @parameterized.expand(TEST_CASES_ERROR_2) def test_channel_dim(self, input_param, input): with self.assertRaises(ValueError): AddCoordinateChannelsd(**input_param)(input) diff --git a/tests/test_resize.py b/tests/test_resize.py index 56353f2bea..8d6eeddcea 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -54,8 +54,11 @@ def test_correct_results(self, in_type, spatial_size, mode): ) ) expected = np.stack(expected).astype(np.float32) - out = resize(in_type(self.imt[0])) + im = in_type(self.imt[0]) + out = resize(im) + self.assertEqual(type(im), type(out)) if isinstance(out, torch.Tensor): + self.assertEqual(im.device, out.device) out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) From b1bb9a15937ce9472b530e817a58b5303bf7e5c0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:35:35 +0100 Subject: [PATCH 117/176] no more todotransform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/transform.py | 6 ------ tests/rb_test_transforms.py | 6 +----- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index c2519d4165..aa40a5fa69 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -283,12 +283,6 @@ class NumpyTransform(Transform): array_class = np.ndarray -class ToDoTransform(Transform): - """Transforms that inherit from this class are still to be updated. This is a temporary class.""" - - pass - - class RandomizableTransform(Randomizable, Transform): """ An interface for handling random state locally, currently based on a class variable `R`, diff --git a/tests/rb_test_transforms.py b/tests/rb_test_transforms.py index a182007954..46e2c478a9 100644 --- a/tests/rb_test_transforms.py +++ b/tests/rb_test_transforms.py @@ -2,7 +2,7 @@ from monai import transforms from monai.transforms import MapTransform, Transform -from monai.transforms.transform import NumpyTransform, ToDoTransform, TorchOrNumpyTransform, TorchTransform +from monai.transforms.transform import NumpyTransform, TorchOrNumpyTransform, TorchTransform class Colours: @@ -49,9 +49,6 @@ def print_colour(t, colour): elif issubclass(obj, NumpyTransform): tr_np += 1 print_colour(f"Numpy: {n}", Colours.yellow) - elif issubclass(obj, ToDoTransform): - tr_todo += 1 - print_colour(f"ToDoTransform: {n}", Colours.purple) else: tr_uncategorised += 1 print_colour(f"Uncategorised: {n}", Colours.red) @@ -59,5 +56,4 @@ def print_colour(t, colour): print_colour(f"Number of TorchOrNumpyTransform: {tr_t_or_np}", Colours.green) print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) -print_colour(f"Number of ToDoTransform: {tr_todo}", Colours.purple) print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) From b7afb91de254ab1b591991c935fd65bdecd397a8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 13:04:35 +0100 Subject: [PATCH 118/176] remove pre_conv_data and post_conv_data Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 4 ++-- monai/transforms/intensity/array.py | 29 ++++++++++++++++------------- monai/transforms/post/array.py | 26 +++++++++++++++----------- monai/transforms/spatial/array.py | 16 ++++++++++------ monai/transforms/transform.py | 25 
+++---------------------- monai/transforms/utility/array.py | 16 +++++++++------- 6 files changed, 55 insertions(+), 61 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index fea46d19e3..fcbe026700 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -701,12 +701,12 @@ def __call__(self, img: DataObjects.Images): slicing doesn't change the channel dim. """ img_np: np.ndarray - img_np, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore box_start, box_end = self.compute_bounding_box(img_np) cropped = self.crop_pad(img_np, box_start, box_end) - cropped = self.post_convert_data(cropped, orig_type, orig_device) + cropped, *_ = convert_data_type(cropped, orig_type, orig_device) if self.return_coords: return cropped, box_start, box_end diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 4a4e01345f..a96d851583 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -883,13 +883,14 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ img_t: torch.Tensor - img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore # add one to transform axis because a batch axis will be added at dimension 0 savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode) # convert to Tensor and add Batch axis expected by HilbertTransform - out = savgol_filter(img_t.unsqueeze(0)).squeeze(0) - return self.post_convert_data(out, orig_type, orig_device) + out_t = savgol_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out class DetectEnvelope(TorchTransform): @@ -926,13 +927,14 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ img_t: torch.Tensor - img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore # add one to transform axis because a batch axis will be added at dimension 0 hilbert_transform = HilbertTransform(self.axis + 1, self.n) # convert to Tensor and add Batch axis expected by HilbertTransform - out = torch.abs(hilbert_transform(img_t.unsqueeze(0))).squeeze(0) - return self.post_convert_data(out, orig_type, orig_device) + out_t = torch.abs(hilbert_transform(img_t.unsqueeze(0))).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out class GaussianSmooth(TorchTransform): @@ -955,12 +957,13 @@ def __init__(self, sigma: Union[Sequence[float], float] = 1.0, approx: str = "er def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t: torch.Tensor - img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore img_t = img_t.to(torch.float) gaussian_filter = GaussianFilter(img.ndim - 1, self.sigma, approx=self.approx).to(img_t.device) - out = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) - return self.post_convert_data(out, orig_type, orig_device) + out_t = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out class RandGaussianSmooth(TorchTransform, RandomizableTransform): @@ -1050,8 +1053,7 @@ def __init__( def __call__(self, 
img: DataObjects.Images) -> DataObjects.Images:
         img_t: torch.Tensor
-        img_t, orig_type, orig_device = self.pre_conv_data(img)  # type: ignore
-        img_t = img_t.to(torch.float)
+        img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float)  # type: ignore
 
         gf1, gf2 = [
             GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
@@ -1059,8 +1061,9 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         ]
         blurred_f = gf1(img_t.unsqueeze(0))
         filter_blurred_f = gf2(blurred_f)
-        out = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0)
-        return self.post_convert_data(out, orig_type, orig_device)
+        out_t = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0)
+        out, *_ = convert_data_type(out_t, orig_type, orig_device)
+        return out
 
 
 class RandGaussianSharpen(TorchTransform, RandomizableTransform):
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 93183c7f28..23d05e7108 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -26,6 +26,7 @@
 from monai.transforms.utils import get_largest_connected_component_mask
 from monai.utils import ensure_tuple
 from monai.utils.enums import DataObjects
+from monai.utils.misc import convert_data_type
 
 __all__ = [
     "Activations",
@@ -89,20 +90,21 @@ def __call__(
         if other is not None and not callable(other):
             raise TypeError(f"other must be None or callable but is {type(other).__name__}.")
 
-        img, orig_type, orig_device = self.pre_conv_data(img)
+        img_t: torch.Tensor
+        img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float)  # type: ignore
 
         # convert to float as activation must operate on float tensor
-        img = img.float()  # type: ignore
         if sigmoid or self.sigmoid:
-            img = torch.sigmoid(img)
+            img_t = torch.sigmoid(img_t)
         if softmax or self.softmax:
-            img = torch.softmax(img, dim=0)
+            img_t = torch.softmax(img_t, dim=0)
 
         act_func = self.other if other is None else other
         if act_func is not None:
-            img = act_func(img)
+            img_t = act_func(img_t)
 
-        return self.post_convert_data(img, orig_type, orig_device)
+        out, *_ = convert_data_type(img_t, orig_type, orig_device)
+        return out
 
 
 class AsDiscrete(TorchTransform):
@@ -168,7 +170,7 @@ def __call__(
         """
         img_t: torch.Tensor
-        img_t, orig_type, orig_device = self.pre_conv_data(img)  # type: ignore
+        img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor)  # type: ignore
 
         if argmax or self.argmax:
             img_t = torch.argmax(img_t, dim=0, keepdim=True)
@@ -182,7 +184,8 @@
         if threshold_values or self.threshold_values:
             img_t = img_t >= (logit_thresh or self.logit_thresh)
 
-        return self.post_convert_data(img_t.float(), orig_type, orig_device)
+        out, *_ = convert_data_type(img_t, orig_type, orig_device, dtype=float)
+        return out
 
 
 class KeepLargestConnectedComponent(TorchTransform):
@@ -260,7 +263,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
             A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]).
         """
         img_t: torch.Tensor
-        img_t, orig_type, orig_device = self.pre_conv_data(img)  # type: ignore
+        img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor)  # type: ignore
 
         if img_t.shape[0] == 1:
             img_t = torch.squeeze(img_t, dim=0)
@@ -293,9 +296,10 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
             background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0)
             applied_img[background_mask] = 0
             img_t[self.applied_labels, ...] 
= applied_img.type(img_t.type()) # type: ignore - output = img_t # type: ignore + output = img_t - return self.post_convert_data(output, orig_type, orig_device) + out, *_ = convert_data_type(output, orig_type, orig_device) + return out class LabelToContour(Transform): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2a835d43f9..c16066ed73 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -345,11 +345,13 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - img, orig_type, orig_device = self.pre_conv_data(img) + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore - result = torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)).to(img.dtype) # type: ignore + result_t = torch.flip(img_t, map_spatial_axes(img.ndim, self.spatial_axis)) - return self.post_convert_data(result, orig_type, orig_device) + result, *_ = convert_data_type(result_t, orig_type, orig_device) + return result class Resize(TorchTransform): @@ -529,7 +531,8 @@ def __call__( spatial_size=output_shape, ) self._rotation_matrix = transform - return self.post_convert_data(output.squeeze(0).float(), orig_type, orig_device) + out, *_ = convert_data_type(output.squeeze(0).float(), orig_type, orig_device) + return out def get_rotation_matrix(self) -> Optional[np.ndarray]: """ @@ -601,7 +604,7 @@ def __call__( """ img_t: torch.Tensor - img_t, orig_type, orig_device = self.pre_conv_data(img) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore _zoom = ensure_tuple_rep(self.zoom, img_t.ndim - 1) # match the spatial image dim zoomed = torch.nn.functional.interpolate( # type: ignore @@ -629,7 +632,8 @@ def __call__( zoomed = padder(zoomed) zoomed = zoomed[tuple(slice_vec)] - return self.post_convert_data(zoomed, orig_type, orig_device) + out, *_ = convert_data_type(zoomed, orig_type, orig_device) + return out class Rotate90(TorchOrNumpyTransform): diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index aa40a5fa69..2b9047cfa7 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Generator, Hashable, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union import numpy as np import torch @@ -23,7 +23,6 @@ from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple from monai.utils.enums import DataObjects -from monai.utils.misc import convert_data_type __all__ = [ "ThreadUnsafe", @@ -214,9 +213,6 @@ class Transform(ABC): :py:class:`monai.transforms.Compose` """ - array_class: Union[Type[torch.Tensor], Type[np.ndarray]] - array_class = torch.Tensor - @abstractmethod def __call__(self, data: Any): """ @@ -245,21 +241,6 @@ def __call__(self, data: Any): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def pre_conv_data(self, data: DataObjects.Images) -> Tuple[DataObjects.Images, type, Optional[torch.device]]: - """Convert to torch/numpy, as required. Also return the original state so that after the transform, - the data can be reverted to its original type. 
- """ - data, orig_type, orig_device = convert_data_type(data, self.array_class) - return data, orig_type, orig_device - - @staticmethod - def post_convert_data( - data: DataObjects.Images, output_type: type, ouput_device: Optional[torch.device] = None - ) -> DataObjects.Images: - """Convert back to original type.""" - data, *_ = convert_data_type(data, output_type, ouput_device) - return data - class TorchOrNumpyTransform(Transform): """Transforms that inherit from this class process the input the same regardless of whether @@ -272,7 +253,7 @@ class TorchTransform(Transform): """Most transforms use torch. If the transforms inherits from this class, then that is the case. If the input is not torch, convert to torch and then convert back at the end.""" - array_class = torch.Tensor + pass class NumpyTransform(Transform): @@ -280,7 +261,7 @@ class NumpyTransform(Transform): This means that if the input image is `torch.Tensor`, it will be converted to numpy and then reverted at the end.""" - array_class = np.ndarray + pass class RandomizableTransform(Randomizable, Transform): diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 11f134a624..8c8a9797f4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -426,9 +426,11 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img, orig_type, orig_device = self.pre_conv_data(img) - img = img.transpose(self.indices) # type: ignore - return self.post_convert_data(img, orig_type, orig_device) + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + img_np = img_np.transpose(self.indices) # type: ignore + out, *_ = convert_data_type(img_np, orig_type, orig_device) + return out class SqueezeDim(TorchOrNumpyTransform): @@ -675,7 +677,7 @@ def __call__( merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. 
""" - img, orig_type, orig_device = self.pre_conv_data(img) + img, orig_type, orig_device = convert_data_type(img, np.ndarray) if select_labels is None: select_labels = self.select_labels @@ -687,9 +689,9 @@ def __call__( else: data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) - out = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data - - return self.post_convert_data(out, orig_type, orig_device) # type: ignore + out_np = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data + out, *_ = convert_data_type(out_np, orig_type, orig_device) + return out class FgBgToIndices(Transform): From 5a58249e148b1d2ca71f0d5403c9dcf470e431e8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 13:21:19 +0100 Subject: [PATCH 119/176] deepgrow update Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/deepgrow/dataset.py | 63 ++++++++++++++++++++++------------ tests/test_deepgrow_dataset.py | 47 +++++++++++-------------- 2 files changed, 61 insertions(+), 49 deletions(-) diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py index acaeba0bc3..b0095c96e6 100644 --- a/monai/apps/deepgrow/dataset.py +++ b/monai/apps/deepgrow/dataset.py @@ -17,6 +17,7 @@ from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd from monai.utils import GridSampleMode +from monai.utils.misc import convert_data_type def create_dataset( @@ -82,7 +83,7 @@ def create_dataset( if not len(datalist): raise ValueError("Input datalist is empty") - transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms + transforms = transforms or _default_transforms(image_key, label_key, pixdim) new_datalist = [] for idx in range(len(datalist)): if limit and idx >= limit: @@ -133,25 +134,34 @@ def _default_transforms(image_key, label_key, pixdim): def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): + if vol_image is not None: + vol_image_np, *_ = convert_data_type(vol_image, np.ndarray) + else: + vol_image_np = vol_image + if vol_label is not None: + vol_label_np, *_ = convert_data_type(vol_label, np.ndarray) + else: + vol_label_np = vol_label + data_list = [] - if len(vol_image.shape) == 4: + if len(vol_image_np.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label is not None else None + vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None ) ) - vol_image = vol_image[0] - vol_image = np.moveaxis(vol_image, -1, 0) + vol_image_np = vol_image_np[0] + vol_image_np = np.moveaxis(vol_image_np, -1, 0) image_count = 0 label_count = 0 unique_labels_count = 0 - for sid in range(vol_image.shape[0]): - image = vol_image[sid, ...] - label = vol_label[sid, ...] if vol_label is not None else None + for sid in range(vol_image_np.shape[0]): + image = vol_image_np[sid, ...] + label = vol_label_np[sid, ...] 
if vol_label_np is not None else None - if vol_label is not None and np.sum(label) == 0: + if vol_label_np is not None and np.sum(label) == 0: continue image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid) @@ -163,7 +173,7 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): image_count += 1 # Test Data - if vol_label is None: + if vol_label_np is None: data_list.append( { "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, @@ -200,9 +210,9 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): logging.info( "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( vol_idx, - vol_image.shape, + vol_image_np.shape, image_count, - vol_label.shape if vol_label is not None else None, + vol_label_np.shape if vol_label_np is not None else None, label_count, unique_labels_count, ) @@ -211,16 +221,25 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): + if vol_image is not None: + vol_image_np, *_ = convert_data_type(vol_image, np.ndarray) + else: + vol_image_np = vol_image + if vol_label is not None: + vol_label_np, *_ = convert_data_type(vol_label, np.ndarray) + else: + vol_label_np = vol_label + data_list = [] - if len(vol_image.shape) == 4: + if len(vol_image_np.shape) == 4: logging.info( "4D-Image, pick only first series; Image: {}; Label: {}".format( - vol_image.shape, vol_label.shape if vol_label is not None else None + vol_image_np.shape, vol_label_np.shape if vol_label_np is not None else None ) ) - vol_image = vol_image[0] - vol_image = np.moveaxis(vol_image, -1, 0) + vol_image_np = vol_image_np[0] + vol_image_np = np.moveaxis(vol_image_np, -1, 0) image_count = 0 label_count = 0 @@ -231,11 +250,11 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): image_file += ".npy" os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True) - np.save(image_file, vol_image) + np.save(image_file, vol_image_np) image_count += 1 # Test Data - if vol_label is None: + if vol_label_np is None: data_list.append( { "image": image_file.replace(dataset_dir + os.pathsep, "") if relative_path else image_file, @@ -243,7 +262,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): ) else: # For all Labels - unique_labels = np.unique(vol_label.flatten()) + unique_labels = np.unique(vol_label_np.flatten()) unique_labels = unique_labels[unique_labels != 0] unique_labels_count = max(unique_labels_count, len(unique_labels)) @@ -252,7 +271,7 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): label_file = os.path.join(dataset_dir, "labels", label_file_prefix) label_file += ".npy" - curr_label = (vol_label == idx).astype(np.float32) + curr_label = (vol_label_np == idx).astype(np.float32) os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True) np.save(label_file, curr_label) @@ -271,9 +290,9 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path): logging.info( "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format( vol_idx, - vol_image.shape, + vol_image_np.shape, image_count, - vol_label.shape if vol_label is not None else None, + vol_label_np.shape if vol_label_np is not None else None, label_count, unique_labels_count, ) diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py index 147d8e7099..d35c4c7e91 100644 --- 
a/tests/test_deepgrow_dataset.py +++ b/tests/test_deepgrow_dataset.py @@ -21,30 +21,25 @@ from monai.apps.deepgrow.dataset import create_dataset from monai.utils import set_determinism -TEST_CASE_1 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 3}, 9, 1] - -TEST_CASE_2 = [{"dimension": 2, "pixdim": (1, 1), "limit": 1}, {"length": 3}, 3, 1] - -TEST_CASE_3 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1}, 3, 1] - -TEST_CASE_4 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1] - -TEST_CASE_5 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1] - -TEST_CASE_6 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1] - -TEST_CASE_7 = [ - {"dimension": 2, "pixdim": (1, 1), "label_key": None}, - {"length": 1, "image_channel": 4, "with_label": False}, - 40, - None, -] - -TEST_CASE_8 = [ - {"dimension": 3, "pixdim": (1, 1, 1), "label_key": None}, - {"length": 1, "image_channel": 4, "with_label": False}, - 1, - None, +TESTS = [ + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 3}, 9, 1], + [{"dimension": 2, "pixdim": (1, 1), "limit": 1}, {"length": 3}, 3, 1], + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1}, 3, 1], + [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1], + [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1], + [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1], + [ + {"dimension": 2, "pixdim": (1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 40, + None, + ], + [ + {"dimension": 3, "pixdim": (1, 1, 1), "label_key": None}, + {"length": 1, "image_channel": 4, "with_label": False}, + 1, + None, + ], ] @@ -79,9 +74,7 @@ def _create_data(self, length=1, image_channel=1, with_label=True): return datalist - @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] - ) + @parameterized.expand(TESTS) def test_create_dataset(self, args, data_args, expected_length, expected_region): datalist = self._create_data(**data_args) deepgrow_datalist = create_dataset(datalist=datalist, output_dir=self.tempdir, **args) From 5a14c2dc181ba9595200773a419baa3738b0d082 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 14:15:14 +0100 Subject: [PATCH 120/176] no TorchOrNumpyTransform Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 16 +++++++-------- monai/transforms/intensity/array.py | 30 ++++++++++++++--------------- monai/transforms/spatial/array.py | 18 +++++------------ monai/transforms/transform.py | 7 ------- monai/transforms/utility/array.py | 28 +++++++++++++-------------- tests/rb_test_transforms.py | 6 +++--- 6 files changed, 45 insertions(+), 60 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index fcbe026700..f028b5c4b4 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -23,7 +23,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, Transform +from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, generate_label_classes_crop_centers, @@ -58,7 +58,7 @@ ] -class Pad(TorchOrNumpyTransform): 
+class Pad(TorchTransform, NumpyTransform): """ Perform padding for a given an amount of padding in each dimension. @@ -122,7 +122,7 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s return pad(img, self.to_pad, mode, **self.np_kwargs) -class SpatialPad(TorchOrNumpyTransform): +class SpatialPad(TorchTransform, NumpyTransform): """ Performs padding to the data, symmetric for all sides or all on one side for each dimension. @@ -190,7 +190,7 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s return padder(img) -class BorderPad(TorchOrNumpyTransform): +class BorderPad(TorchTransform, NumpyTransform): """ Pad the input data by adding specified borders to every dimension. @@ -263,7 +263,7 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s return padder(img) -class DivisiblePad(TorchOrNumpyTransform): +class DivisiblePad(TorchTransform, NumpyTransform): """ Pad the input data, so that the spatial sizes are divisible by `k`. """ @@ -317,7 +317,7 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s return spatial_pad(img) -class SpatialCrop(TorchOrNumpyTransform): +class SpatialCrop(TorchTransform, NumpyTransform): """ General purpose cropper to produce sub-volume region of interest (ROI). If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -380,7 +380,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img[tuple(slices)] -class CenterSpatialCrop(TorchOrNumpyTransform): +class CenterSpatialCrop(TorchTransform, NumpyTransform): """ Crop at the center of image with specified ROI size. If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension. @@ -901,7 +901,7 @@ def __call__( return results -class RandCropByLabelClasses(Randomizable, TorchOrNumpyTransform): +class RandCropByLabelClasses(Randomizable, TorchTransform, NumpyTransform): """ Crop random fixed sized regions with the center being a class based on the specified ratios of every class. The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a96d851583..b9aa45010f 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -24,7 +24,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import RandomizableTransform, TorchOrNumpyTransform, TorchTransform, Transform +from monai.transforms.transform import NumpyTransform, RandomizableTransform, TorchTransform, Transform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, ensure_tuple, ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects @@ -94,7 +94,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img + self._noise.astype(dtype) -class RandRicianNoise(TorchOrNumpyTransform, RandomizableTransform): +class RandRicianNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Add Rician noise to image. 
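With `TorchOrNumpyTransform` removed, dual-backend support is now expressed by inheriting from the two marker classes together, so it can be recovered with plain `issubclass` checks, exactly as the `rb_test_transforms.py` hunk later in this patch does. A small sketch of the same idea:

    from monai.transforms.transform import NumpyTransform, TorchTransform

    def backends(transform_cls):
        # a transform may carry either marker, or both
        out = []
        if issubclass(transform_cls, TorchTransform):
            out.append("torch")
        if issubclass(transform_cls, NumpyTransform):
            out.append("numpy")
        return out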
Rician noise in MRI is the result of performing a magnitude operation on complex @@ -173,7 +173,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class ShiftIntensity(TorchOrNumpyTransform): +class ShiftIntensity(TorchTransform, NumpyTransform): """ Shift intensity uniformly for the entire image with specified `offset`. @@ -194,7 +194,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return out.astype(img.dtype) # type: ignore -class RandShiftIntensity(RandomizableTransform, TorchOrNumpyTransform): +class RandShiftIntensity(RandomizableTransform, TorchTransform, NumpyTransform): """ Randomly shift intensity with randomly picked offset. """ @@ -230,7 +230,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return shifter(img) -class StdShiftIntensity(TorchOrNumpyTransform): +class StdShiftIntensity(TorchTransform, NumpyTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)``. @@ -284,7 +284,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class RandStdShiftIntensity(TorchOrNumpyTransform, RandomizableTransform): +class RandStdShiftIntensity(TorchTransform, NumpyTransform, RandomizableTransform): """ Shift intensity for the image with a factor and the standard deviation of the image by: ``v = v + factor * std(v)`` where the `factor` is randomly picked. @@ -337,7 +337,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return shifter(img) -class ScaleIntensity(TorchOrNumpyTransform): +class ScaleIntensity(TorchTransform, NumpyTransform): """ Scale the intensity of input image to the given value range (minv, maxv). If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``. @@ -376,7 +376,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") -class RandScaleIntensity(TorchOrNumpyTransform, RandomizableTransform): +class RandScaleIntensity(TorchTransform, NumpyTransform, RandomizableTransform): """ Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor` is randomly picked. @@ -604,7 +604,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: ) -class ScaleIntensityRange(TorchOrNumpyTransform): +class ScaleIntensityRange(TorchTransform, NumpyTransform): """ Apply specific intensity scaling to the whole numpy array. Scaling from [a_min, a_max] to [b_min, b_max] with clip option. @@ -640,7 +640,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class AdjustContrast(TorchOrNumpyTransform): +class AdjustContrast(TorchTransform, NumpyTransform): """ Changes image intensity by gamma. Each pixel/voxel intensity is updated as:: @@ -665,7 +665,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min -class RandAdjustContrast(TorchOrNumpyTransform, RandomizableTransform): +class RandAdjustContrast(TorchTransform, NumpyTransform, RandomizableTransform): """ Randomly changes image intensity by gamma. 
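For reference, the deterministic gamma update shared by `AdjustContrast` and `RandAdjustContrast` normalises intensities to [0, 1], applies the exponent, then restores the original range; a worked example (the epsilon value here is illustrative):

    import numpy as np

    img = np.array([0.0, 1.0, 2.0])
    gamma, eps = 2.0, 1e-7
    lo, rng = img.min(), img.ptp()
    out = ((img - lo) / (rng + eps)) ** gamma * rng + lo
    # out ~= [0.0, 0.5, 2.0]: mid-range values darken, endpoints are preserved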
Each pixel/voxel intensity is updated as:: @@ -710,7 +710,7 @@ def __call__(self, img: np.ndarray): return adjuster(img) -class ScaleIntensityRangePercentiles(TorchOrNumpyTransform): +class ScaleIntensityRangePercentiles(TorchTransform, NumpyTransform): """ Apply range scaling to a numpy array based on the intensity distribution of the input. @@ -802,7 +802,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class MaskIntensity(TorchOrNumpyTransform): +class MaskIntensity(TorchTransform, NumpyTransform): """ Mask the intensity values of input image with the specified mask data. Mask data must have the same spatial size as the input image, and all @@ -1179,7 +1179,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: ) -class RandGibbsNoise(TorchOrNumpyTransform, RandomizableTransform): +class RandGibbsNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Naturalistic image augmentation via Gibbs artifacts. The transform randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts @@ -1234,7 +1234,7 @@ def _randomize(self, _: Any) -> None: self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1]) -class GibbsNoise(TorchOrNumpyTransform): +class GibbsNoise(TorchTransform, NumpyTransform): """ The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts are one of the common type of type artifacts appearing in MRI scans. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c16066ed73..f588a64fd2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -24,15 +24,7 @@ from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop, Pad -from monai.transforms.transform import ( - NumpyTransform, - Randomizable, - RandomizableTransform, - ThreadUnsafe, - TorchOrNumpyTransform, - TorchTransform, - Transform, -) +from monai.transforms.transform import NumpyTransform, Randomizable, RandomizableTransform, ThreadUnsafe, TorchTransform from monai.transforms.utils import ( create_control_grid, create_grid, @@ -636,7 +628,7 @@ def __call__( return out -class Rotate90(TorchOrNumpyTransform): +class Rotate90(TorchTransform, NumpyTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See np.rot90 for additional details: @@ -668,7 +660,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return np.rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)).astype(img.dtype) # type: ignore -class RandRotate90(TorchOrNumpyTransform, RandomizableTransform): +class RandRotate90(TorchTransform, NumpyTransform, RandomizableTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -1146,7 +1138,7 @@ def get_transformation_matrix(self) -> Optional[DataObjects.Images]: return self.affine -class RandDeformGrid(Randomizable, Transform): +class RandDeformGrid(Randomizable, TorchTransform, NumpyTransform): """ Generate random deformation grid. """ @@ -1766,7 +1758,7 @@ def __call__( return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) -class AddCoordinateChannels(TorchOrNumpyTransform): +class AddCoordinateChannels(TorchTransform, NumpyTransform): """ Appends additional channels encoding coordinates of the input. Useful when e.g. 
training using patch-based sampling, to allow feeding of the patch's location into the network. diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 2b9047cfa7..ee3b1c092c 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -242,13 +242,6 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class TorchOrNumpyTransform(Transform): - """Transforms that inherit from this class process the input the same regardless of whether - the input is torch or numpy. No conversions are needed.""" - - pass - - class TorchTransform(Transform): """Most transforms use torch. If the transforms inherits from this class, then that is the case. If the input is not torch, convert to torch and then convert back at the end.""" diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 8c8a9797f4..5edb875c5a 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -23,7 +23,7 @@ import torch from monai.config import DtypeLike -from monai.transforms.transform import NumpyTransform, Randomizable, TorchOrNumpyTransform, TorchTransform, Transform +from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform from monai.transforms.utils import ( convert_to_numpy, convert_to_tensor, @@ -70,7 +70,7 @@ ] -class Identity(TorchOrNumpyTransform): +class Identity(TorchTransform, NumpyTransform): """ Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data. As the output value is same as input, it can be used as a testing tool to verify the transform chain, @@ -85,7 +85,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img -class AsChannelFirst(TorchOrNumpyTransform): +class AsChannelFirst(TorchTransform, NumpyTransform): """ Change the channel dimension of the image to the first dimension. @@ -118,7 +118,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return convert_data_type(img, orig_type, orig_device)[0] -class AsChannelLast(TorchOrNumpyTransform): +class AsChannelLast(TorchTransform, NumpyTransform): """ Change the channel dimension of the image to the last dimension. @@ -150,7 +150,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return convert_data_type(img, orig_type, orig_device)[0] -class AddChannel(TorchOrNumpyTransform): +class AddChannel(TorchTransform, NumpyTransform): """ Adds a 1-length channel dimension to the input image. @@ -210,7 +210,7 @@ def __call__(self, img: DataObjects.Images, meta_dict: Optional[Dict] = None) -> return AsChannelFirst(channel_dim=channel_dim)(img) -class RepeatChannel(TorchOrNumpyTransform): +class RepeatChannel(TorchTransform, NumpyTransform): """ Repeat channel data to construct expected input shape for models. 
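The `repeeat_fn` used below is presumably selected per container; a hypothetical sketch of that dispatch (the helper name and selection logic are assumptions, the primitives are `np.repeat` and `torch.repeat_interleave`):

    import numpy as np
    import torch

    def _repeat(img, repeats, axis):
        if isinstance(img, np.ndarray):
            return np.repeat(img, repeats, axis)
        return torch.repeat_interleave(img, repeats, dim=axis)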
The `repeats` count includes the origin data, for example: @@ -233,7 +233,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return repeeat_fn(img, self.repeats, 0) # type: ignore -class RemoveRepeatedChannel(TorchOrNumpyTransform): +class RemoveRepeatedChannel(TorchTransform, NumpyTransform): """ RemoveRepeatedChannel data to undo RepeatChannel The `repeats` count specifies the deletion of the origin data, for example: @@ -259,7 +259,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img[:: self.repeats, :] -class SplitChannel(TorchOrNumpyTransform): +class SplitChannel(TorchTransform, NumpyTransform): """ Split Numpy array or PyTorch Tensor data according to the channel dim. It can help applying different following transforms to different channels. @@ -286,7 +286,7 @@ def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: return outputs -class CastToType(TorchOrNumpyTransform): +class CastToType(TorchTransform, NumpyTransform): """ Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to specified PyTorch data type. @@ -433,7 +433,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return out -class SqueezeDim(TorchOrNumpyTransform): +class SqueezeDim(TorchTransform, NumpyTransform): """ Squeeze a unitary dimension. """ @@ -465,7 +465,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return img.squeeze(self.dim) -class DataStats(TorchOrNumpyTransform): +class DataStats(TorchTransform, NumpyTransform): """ Utility transform to show the statistics of data for debug or analysis. It can be inserted into any place of a transform chain and check results of previous transforms. @@ -556,7 +556,7 @@ def __call__( return img -class SimulateDelay(TorchOrNumpyTransform): +class SimulateDelay(TorchTransform, NumpyTransform): """ This is a pass through transform to be used for testing purposes. It allows adding fake behaviors that are useful for testing purposes to simulate @@ -736,7 +736,7 @@ def __call__( return fg_indices, bg_indices -class ClassesToIndices(TorchOrNumpyTransform): +class ClassesToIndices(TorchTransform, NumpyTransform): def __init__( self, num_classes: Optional[int] = None, @@ -812,7 +812,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: return np.stack(result, axis=0) -class AddExtremePointsChannel(Randomizable, Transform): +class AddExtremePointsChannel(Randomizable, TorchTransform): """ Add extreme points of label to the image as a new channel. This transform generates extreme point from label and applies a gaussian filter. 
The pixel values in points image are rescaled diff --git a/tests/rb_test_transforms.py b/tests/rb_test_transforms.py index 46e2c478a9..798f000345 100644 --- a/tests/rb_test_transforms.py +++ b/tests/rb_test_transforms.py @@ -2,7 +2,7 @@ from monai import transforms from monai.transforms import MapTransform, Transform -from monai.transforms.transform import NumpyTransform, TorchOrNumpyTransform, TorchTransform +from monai.transforms.transform import NumpyTransform, TorchTransform class Colours: @@ -40,7 +40,7 @@ def print_colour(t, colour): ]: continue tr_total += 1 - if issubclass(obj, TorchOrNumpyTransform): + if issubclass(obj, TorchTransform) and issubclass(obj, NumpyTransform): tr_t_or_np += 1 print_colour(f"TorchOrNumpy: {n}", Colours.green) elif issubclass(obj, TorchTransform): @@ -53,7 +53,7 @@ def print_colour(t, colour): tr_uncategorised += 1 print_colour(f"Uncategorised: {n}", Colours.red) print("Total number of transforms:", tr_total) -print_colour(f"Number of TorchOrNumpyTransform: {tr_t_or_np}", Colours.green) +print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green) print_colour(f"Number of TorchTransform: {tr_t}", Colours.green) print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow) print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red) From 2b82e87261a2f776fe40a4a96f8c2b2bafbb3ec6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 14:15:47 +0100 Subject: [PATCH 121/176] AddExtremePointsChannel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 25 ++++--- monai/transforms/utility/dictionary.py | 8 ++- monai/transforms/utils.py | 27 +++++--- tests/test_add_extreme_points_channel.py | 83 +++++++++++++---------- tests/test_add_extreme_points_channeld.py | 70 ++++++++++++------- tests/test_get_extreme_points.py | 48 +++++++------ 6 files changed, 158 insertions(+), 103 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5edb875c5a..765caf8395 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -837,17 +837,17 @@ def __init__(self, background: int = 0, pert: float = 0.0) -> None: self._pert = pert self._points: List[Tuple[int, ...]] = [] - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: DataObjects.Images) -> None: self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert) def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0, rescale_min: float = -1.0, rescale_max: float = 1.0, - ): + ) -> DataObjects.Images: """ Args: img: the image that we want to add new channel to. 
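Taken together, the two hunks in this file let the transform accept either container and keep the appended points channel on the input's device. A usage sketch, with shapes taken from the test cases further down:

    import torch

    from monai.transforms import AddExtremePointsChannel

    img = torch.zeros((3, 4, 3))  # channel-first image
    label = torch.as_tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])
    out = AddExtremePointsChannel()(img, label, sigma=1.0, rescale_min=0.0, rescale_max=1.0)
    assert out.shape == (4, 4, 3)  # the gaussian-filtered points map is a new channel
    assert out.device == img.device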
@@ -864,14 +864,23 @@ def __call__( if label.shape[0] != 1: raise ValueError("Only supports single channel labels!") + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + # Generate extreme points self.randomize(label[0, :]) - points_image = extreme_points_to_image( - points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max + points_image_t = extreme_points_to_image( + points=self._points, + label=label, + sigma=sigma, + rescale_min=rescale_min, + rescale_max=rescale_max, + device=img_t.device, ) - - return np.concatenate([img, points_image], axis=0) + out_t = torch.cat((img_t, points_image_t), dim=0) + out, *_ = convert_data_type(out_t, orig_type, orig_device) + return out class TorchVision: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 494392165d..eba14b475b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1076,7 +1076,7 @@ def __init__( self.rescale_min = rescale_min self.rescale_max = rescale_max - def randomize(self, label: np.ndarray) -> None: + def randomize(self, label: DataObjects.Images) -> None: self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert) def __call__(self, data): @@ -1096,8 +1096,12 @@ def __call__(self, data): sigma=self.sigma, rescale_min=self.rescale_min, rescale_max=self.rescale_max, + device=img.device if isinstance(img, torch.Tensor) else None, ) - d[key] = np.concatenate([img, points_image], axis=0) + if isinstance(img, np.ndarray): + d[key] = np.concatenate([img, points_image], axis=0) + else: + d[key] = torch.cat((img, points_image), dim=0) return d diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index a060fa4fc2..7a6562f18d 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -738,7 +738,7 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option def get_extreme_points( - img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 + img: DataObjects.Images, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: """ Generate extreme points from an image. 
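With `pert=0.0` this reduces to taking, per dimension, the foreground indices of smallest and largest coordinate. One of the cases from `tests/test_get_extreme_points.py` below, worked through:

    import numpy as np

    from monai.transforms import get_extreme_points

    img = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])
    # ordered dim0-min, dim0-max, dim1-min, dim1-max
    points = get_extreme_points(img, rand_state=np.random, background=0, pert=0.0)
    assert points == [(0, 1), (3, 1), (1, 0), (1, 2)]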
These are used to generate initial segmentation @@ -762,9 +762,12 @@ def get_extreme_points( """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore - indices = np.where(img != background) + where = np.where if isinstance(img, np.ndarray) else torch.where + indices = where(img != background) if np.size(indices[0]) == 0: raise ValueError("get_extreme_points: no foreground object in mask!") + if isinstance(img, torch.Tensor): + indices = tuple(i.cpu() for i in indices) def _get_point(val, dim): """ @@ -786,19 +789,20 @@ def _get_point(val, dim): points = [] for i in range(img.ndim): - points.append(tuple(_get_point(np.min(indices[i][...]), i))) - points.append(tuple(_get_point(np.max(indices[i][...]), i))) + points.append(tuple(_get_point(indices[i][...].min(), i))) + points.append(tuple(_get_point(indices[i][...].max(), i))) return points def extreme_points_to_image( points: List[Tuple[int, ...]], - label: np.ndarray, + label: DataObjects.Images, sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0, rescale_min: float = -1.0, rescale_max: float = 1.0, -): + device: Optional[Union[torch.device, str]] = None, +) -> torch.Tensor: """ Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage. @@ -814,20 +818,21 @@ def extreme_points_to_image( use it for all spatial dimensions. rescale_min: minimum value of output data. rescale_max: maximum value of output data. + device: device on which to return result """ # points to image - points_image = torch.zeros(label.shape[1:], dtype=torch.float) + points_image = torch.zeros(label.shape[1:], dtype=torch.float, device=device) for p in points: points_image[p] = 1.0 # add channel and add batch points_image = points_image.unsqueeze(0).unsqueeze(0) - gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma) - points_image = gaussian_filter(points_image).squeeze(0).detach().numpy() + gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma).to(device) + points_image = gaussian_filter(points_image).squeeze(0).detach() # rescale the points image to [rescale_min, rescale_max] - min_intensity = np.min(points_image) - max_intensity = np.max(points_image) + min_intensity = points_image.min() + max_intensity = points_image.max() points_image = (points_image - min_intensity) / (max_intensity - min_intensity) points_image = points_image * (rescale_max - rescale_min) + rescale_min return points_image diff --git a/tests/test_add_extreme_points_channel.py b/tests/test_add_extreme_points_channel.py index ecf2c83d3c..c062f809ba 100644 --- a/tests/test_add_extreme_points_channel.py +++ b/tests/test_add_extreme_points_channel.py @@ -12,54 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddExtremePointsChannel +from tests.utils import TEST_NDARRAYS IMG_CHANNEL = 3 -TEST_CASE_1 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - { - "img": np.zeros((IMG_CHANNEL, 4, 3)), - "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]]), - "sigma": 1.0, - "rescale_min": 0.0, - "rescale_max": 1.0, - }, - np.array( - [ - [0.44628328, 
0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), + ] + ) + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + "sigma": 1.0, + "rescale_min": 0.0, + "rescale_max": 1.0, + }, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), + ] + ) class TestAddExtremePointsChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChannel() result = add_extreme_points_channel(**input_data) + self.assertEqual(type(input_data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data["img"].device) + result = result.cpu() np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) diff --git a/tests/test_add_extreme_points_channeld.py b/tests/test_add_extreme_points_channeld.py index e33bb0838c..30cdb5ee82 100644 --- a/tests/test_add_extreme_points_channeld.py +++ b/tests/test_add_extreme_points_channeld.py @@ -12,45 +12,63 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import AddExtremePointsChanneld +from tests.utils import TEST_NDARRAYS IMG_CHANNEL = 3 -TEST_CASE_1 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])}, - np.array( - [ - [0.38318458, 0.98615628, 0.85551184], - [0.35422316, 0.94430935, 1.0], - [0.46000731, 0.57319659, 0.46000722], - [0.64577687, 0.38318464, 0.0], - ] - ), -] - -TEST_CASE_2 = [ - {"img": np.zeros((IMG_CHANNEL, 4, 3)), "label": np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])}, - np.array( - [ - [0.44628328, 0.80495411, 0.44628328], - [0.6779086, 1.0, 0.67790854], - [0.33002687, 0.62079221, 0.33002687], - [0.0, 0.31848389, 0.0], - ] - ), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])), + }, + np.array( + [ + [0.38318458, 0.98615628, 0.85551184], + [0.35422316, 0.94430935, 1.0], + [0.46000731, 0.57319659, 0.46000722], + [0.64577687, 0.38318464, 0.0], + ] + ), + ] + ) + TESTS.append( + [ + { + "img": p(np.zeros((IMG_CHANNEL, 4, 3))), + "label": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])), + }, + np.array( + [ + [0.44628328, 0.80495411, 0.44628328], + [0.6779086, 1.0, 0.67790854], + [0.33002687, 0.62079221, 0.33002687], + [0.0, 0.31848389, 0.0], + ] + ), + ] + ) class TestAddExtremePointsChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_correct_results(self, input_data, expected): add_extreme_points_channel = AddExtremePointsChanneld( keys="img", label_key="label", sigma=1.0, rescale_min=0.0, rescale_max=1.0 ) - result = 
add_extreme_points_channel(input_data) - np.testing.assert_allclose(result["img"][IMG_CHANNEL], expected, rtol=1e-4) + result = add_extreme_points_channel(input_data)["img"] + + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data["img"].device) + result = result.cpu() + np.testing.assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_get_extreme_points.py b/tests/test_get_extreme_points.py index a334c12415..af655bfb86 100644 --- a/tests/test_get_extreme_points.py +++ b/tests/test_get_extreme_points.py @@ -15,30 +15,36 @@ from parameterized import parameterized from monai.transforms import get_extreme_points - -TEST_CASE_1 = [ - { - "img": np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 0), (3, 0), (1, 2)], -] - -TEST_CASE_2 = [ - { - "img": np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]), - "rand_state": np.random, - "background": 0, - "pert": 0.0, - }, - [(0, 1), (3, 1), (1, 0), (1, 2)], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 0), (3, 0), (1, 2)], + ] + ) + TESTS.append( + [ + { + "img": p(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])), + "rand_state": np.random, + "background": 0, + "pert": 0.0, + }, + [(0, 1), (3, 1), (1, 0), (1, 2)], + ] + ) class TestGetExtremePoints(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected): result = get_extreme_points(**input_data) self.assertEqual(result, expected) From 912b595cf5ef7187e018c2d9aac8499a70d234c0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 14:47:24 +0100 Subject: [PATCH 122/176] probnms Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/array.py | 31 ++++---- tests/test_probnms.py | 131 ++++++++++++++++--------------- tests/test_probnmsd.py | 136 ++++++++++++++++----------------- 3 files changed, 144 insertions(+), 154 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 23d05e7108..d887e6e949 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -14,7 +14,8 @@ """ import warnings -from typing import Callable, Optional, Sequence, Union +from copy import deepcopy +from typing import Callable, List, Optional, Sequence, Union import numpy as np import torch @@ -438,7 +439,7 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te return torch.round(img_) -class ProbNMS(Transform): +class ProbNMS(TorchTransform): """ Performs probability based non-maximum suppression (NMS) on the probabilities map via iteratively selecting the coordinate with highest probability and then move it as well @@ -498,28 +499,22 @@ def __init__( def __call__( self, prob_map: DataObjects.Images, - ): + ) -> List[List]: """ prob_map: the input probabilities map, it must have shape (H[, W, ...]). 
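A usage sketch matching the rewritten tests below: peaks above `prob_threshold` come back highest-first as `[prob, *coords]`, and a `box_size` neighbourhood around each accepted peak is zeroed before the next pick:

    import numpy as np

    from monai.transforms.post.array import ProbNMS

    prob_map = np.random.rand(100, 100).clip(0, 0.5)
    prob_map[33, 33] = 0.7
    prob_map[66, 66] = 0.9
    peaks = ProbNMS(spatial_dims=2, prob_threshold=0.5, box_size=10)(prob_map)
    assert peaks == [[0.9, 66, 66], [0.7, 33, 33]]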
""" + prob_map_t: torch.Tensor + prob_map_t, *_ = convert_data_type(deepcopy(prob_map), torch.Tensor, dtype=float) # type: ignore if self.sigma != 0: - if not isinstance(prob_map, torch.Tensor): - prob_map = torch.as_tensor(prob_map, dtype=torch.float) - self.filter.to(prob_map) - prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() - - if isinstance(prob_map, torch.Tensor): - prob_map = prob_map.detach().cpu().numpy() + self.filter.to(prob_map_t) + prob_map_t = self.filter(prob_map_t) - prob_map_shape = prob_map.shape + prob_map_shape = prob_map_t.shape outputs = [] - while np.max(prob_map) > self.prob_threshold: - max_idx = np.unravel_index(prob_map.argmax(), prob_map_shape) - prob_max = prob_map[max_idx] + while prob_map_t.max() > self.prob_threshold: + max_idx = np.unravel_index(prob_map_t.argmax().cpu(), prob_map_shape) + prob_max = prob_map_t[max_idx].item() max_idx_arr = np.asarray(max_idx) outputs.append([prob_max] + list(max_idx_arr)) @@ -527,6 +522,6 @@ def __call__( idx_max_range = (max_idx_arr + self.box_upper_bd).clip(None, prob_map_shape) # for each dimension, set values during index ranges to 0 slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)) - prob_map[slices] = 0 + prob_map_t[slices] = 0 return outputs diff --git a/tests/test_probnms.py b/tests/test_probnms.py index e51d1017d8..e43b7ddcfb 100644 --- a/tests/test_probnms.py +++ b/tests/test_probnms.py @@ -16,83 +16,80 @@ from parameterized import parameterized from monai.transforms.post.array import ProbNMS +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []] +TESTS = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, probs_map_1, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - probs_map_2, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - probs_map_3, - expected_3, -] - -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - probs_map_4, - expected_4, -] - -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, + probs_map_2, + expected_2, + ] + ) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_6, []] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + 
expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, + probs_map_3, + expected_3, + ] + ) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - probs_map_7, - expected_7, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, + probs_map_4, + expected_4, + ] + ) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - probs_map_3d, - expected_3d, -] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, probs_map_5, []]) + probs_map_6 = p((np.random.rand(100, 100).clip(0, 0.5))) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, + probs_map_6, + expected_6, + ] + ) -class TestProbNMS(unittest.TestCase): - @parameterized.expand( + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, + {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, + probs_map_3d, + expected_3d, ] ) + + +class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): nms = ProbNMS(**class_args) output = nms(probs_map) diff --git a/tests/test_probnmsd.py b/tests/test_probnmsd.py index 5b75d4310f..c9ae7c6a46 100644 --- a/tests/test_probnmsd.py +++ b/tests/test_probnmsd.py @@ -10,91 +10,89 @@ # limitations under the License. 
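The dictionary wrapper, renamed `ProbNMSd` in the next patch, applies the same suppression to the value under each key and stores the resulting peak list back in its place, as the updated test asserts. For illustration:

    import numpy as np

    from monai.transforms.post.dictionary import ProbNMSd

    data = {"prob_map": np.random.rand(100, 100).clip(0, 0.5)}
    data["prob_map"][33, 33] = 0.7
    nms = ProbNMSd(keys="prob_map", spatial_dims=2, prob_threshold=0.5, box_size=10)
    assert nms(data)["prob_map"] == [[0.7, 33, 33]]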
import unittest +from typing import List import numpy as np import torch from parameterized import parameterized -from monai.transforms.post.dictionary import ProbNMSD +from monai.transforms.post.dictionary import ProbNMSd +from tests.utils import TEST_NDARRAYS -probs_map_1 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_1 = [{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []] +TESTS: List[List] = [] +for p in TEST_NDARRAYS: + probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, "box_size": 10}, {"prob_map": probs_map_1}, []]) -probs_map_2 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_2[33, 33] = 0.7 -probs_map_2[66, 66] = 0.9 -expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_2 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, - {"prob_map": probs_map_2}, - expected_2, -] - -probs_map_3 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_3[56, 58] = 0.7 -probs_map_3[60, 66] = 0.8 -probs_map_3[66, 66] = 0.9 -expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] -TEST_CASES_2D_3 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, - {"prob_map": probs_map_3}, - expected_3, -] - -probs_map_4 = np.random.rand(100, 100).clip(0, 0.5) -probs_map_4[33, 33] = 0.7 -probs_map_4[66, 66] = 0.9 -expected_4 = [[0.9, 66, 66]] -TEST_CASES_2D_4 = [ - {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, - {"prob_map": probs_map_4}, - expected_4, -] - -probs_map_5 = np.random.rand(100, 100).clip(0, 0.5) -TEST_CASES_2D_5 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_5}, []] + probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_2[33, 33] = 0.7 + probs_map_2[66, 66] = 0.9 + expected_2 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": [10, 10]}, + {"prob_map": probs_map_2}, + expected_2, + ] + ) -probs_map_6 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -TEST_CASES_2D_6 = [{"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, {"prob_map": probs_map_6}, []] + probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_3[56, 58] = 0.7 + probs_map_3[60, 66] = 0.8 + probs_map_3[66, 66] = 0.9 + expected_3 = [[0.9, 66, 66], [0.8, 60, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "box_size": (10, 20)}, + {"prob_map": probs_map_3}, + expected_3, + ] + ) -probs_map_7 = torch.as_tensor(np.random.rand(100, 100).clip(0, 0.5)) -probs_map_7[33, 33] = 0.7 -probs_map_7[66, 66] = 0.9 -if torch.cuda.is_available(): - probs_map_7 = probs_map_7.cuda() -expected_7 = [[0.9, 66, 66], [0.7, 33, 33]] -TEST_CASES_2D_7 = [ - {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, - {"prob_map": probs_map_7}, - expected_7, -] + probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_4[33, 33] = 0.7 + probs_map_4[66, 66] = 0.9 + expected_4 = [[0.9, 66, 66]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.8, "box_size": 10}, + {"prob_map": probs_map_4}, + expected_4, + ] + ) -probs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5) -probs_map_3d[25, 25, 25] = 0.7 -probs_map_3d[45, 45, 45] = 0.9 -expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] -TEST_CASES_3D = [ - {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, - {"prob_map": probs_map_3d}, - expected_3d, -] + probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5)) + TESTS.append([{"spatial_dims": 2, "prob_threshold": 0.5, 
"sigma": 0.1}, {"prob_map": probs_map_5}, []]) + probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5)) + probs_map_6[33, 33] = 0.7 + probs_map_6[66, 66] = 0.9 + expected_6 = [[0.9, 66, 66], [0.7, 33, 33]] + TESTS.append( + [ + {"spatial_dims": 2, "prob_threshold": 0.5, "sigma": 0.1}, + {"prob_map": probs_map_6}, + expected_6, + ] + ) -class TestProbNMS(unittest.TestCase): - @parameterized.expand( + probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5)) + probs_map_3d[25, 25, 25] = 0.7 + probs_map_3d[45, 45, 45] = 0.9 + expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]] + TESTS.append( [ - TEST_CASES_2D_1, - TEST_CASES_2D_2, - TEST_CASES_2D_3, - TEST_CASES_2D_4, - TEST_CASES_2D_5, - TEST_CASES_2D_6, - TEST_CASES_2D_7, - TEST_CASES_3D, + {"spatial_dims": 3, "prob_threshold": 0.5, "box_size": (10, 10, 10)}, + {"prob_map": probs_map_3d}, + expected_3d, ] ) + + +class TestProbNMS(unittest.TestCase): + @parameterized.expand(TESTS) def test_output(self, class_args, probs_map, expected): - nms = ProbNMSD(keys="prob_map", **class_args) + nms = ProbNMSd(keys="prob_map", **class_args) output = nms(probs_map) np.testing.assert_allclose(output["prob_map"], expected) From 199a37486ad0f143c654c235b7b1b7d2c9c4a670 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 15:05:52 +0100 Subject: [PATCH 123/176] save image Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/io/array.py | 4 +-- tests/test_save_image.py | 39 ++++++++++++++++++--------- tests/test_save_imaged.py | 51 ++++++++++++++++++++++-------------- 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 453c709f35..203fea35ba 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -23,7 +23,7 @@ from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.data.nifti_saver import NiftiSaver from monai.data.png_saver import PNGSaver -from monai.transforms.transform import Transform +from monai.transforms.transform import NumpyTransform, TorchTransform, Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, ensure_tuple, optional_import @@ -183,7 +183,7 @@ def __call__( return img_array, meta_data -class SaveImage(Transform): +class SaveImage(TorchTransform, NumpyTransform): """ Save transformed data into files, support NIfTI and PNG formats. 
It can work for both numpy array and PyTorch Tensor in both preprocessing transform diff --git a/tests/test_save_image.py b/tests/test_save_image.py index f7c8e07f06..c9ee2585e2 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -17,24 +17,37 @@ from parameterized import parameterized from monai.transforms import SaveImage +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - {"filename_or_obj": "testfile0.nii.gz"}, - ".nii.gz", - False, -] +_, has_pil = optional_import("PIL") +_, has_nib = optional_import("nibabel") -TEST_CASE_2 = [ - torch.randint(0, 255, (1, 2, 3, 4)), - None, - ".nii.gz", - False, -] +exts = [ext for has_lib, ext in zip((has_nib, has_pil), (".nii.gz", ".png")) if has_lib] + +TESTS = [] +for p in TEST_NDARRAYS: + for ext in exts: + TESTS.append( + [ + p(torch.randint(0, 255, (1, 2, 3, 4))), + {"filename_or_obj": "testfile0" + ext}, + ext, + False, + ] + ) + TESTS.append( + [ + p(torch.randint(0, 255, (1, 2, 3, 4))), + None, + ext, + False, + ] + ) class TestSaveImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS, skip_on_empty=True) def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 35bbea9628..d512c81f9d 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -17,29 +17,42 @@ from parameterized import parameterized from monai.transforms import SaveImaged +from monai.utils.module import optional_import +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": torch.randint(0, 255, (1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, - }, - ".nii.gz", - False, -] - -TEST_CASE_2 = [ - { - "img": torch.randint(0, 255, (1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, - "patch_index": 6, - }, - ".nii.gz", - False, -] +_, has_pil = optional_import("PIL") +_, has_nib = optional_import("nibabel") + +exts = [ext for has_lib, ext in zip((has_nib, has_pil), (".nii.gz", ".png")) if has_lib] + +TESTS = [] +for p in TEST_NDARRAYS: + for ext in exts: + TESTS.append( + [ + { + "img": p(torch.randint(0, 255, (1, 2, 3, 4))), + "img_meta_dict": {"filename_or_obj": "testfile0" + ext}, + }, + ext, + False, + ] + ) + TESTS.append( + [ + { + "img": p(torch.randint(0, 255, (1, 2, 3, 4))), + "img_meta_dict": {"filename_or_obj": "testfile0" + ext}, + "patch_index": 6, + }, + ext, + False, + ] + ) class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS, skip_on_empty=True) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged( From 731fe9f4ba10db53e9b254584fc6368a38a2ba1d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 16:36:40 +0100 Subject: [PATCH 124/176] roc fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 4 + monai/metrics/rocauc.py | 2 +- monai/transforms/post/array.py | 2 +- tests/test_compute_roc_auc.py | 140 ++++++++++++++++++--------------- 4 files changed, 82 insertions(+), 66 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 21c76c1cf1..32db1a7170 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py 
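# The monai/data/utils.py hunk below teaches ``decollate_batch`` to split numpy
# batches the way it already splits torch ones. A minimal standalone sketch of
# the added behaviour (``decollate_numpy`` is a hypothetical name for this
# sketch, not the MONAI API):
import numpy as np

def decollate_numpy(batch):
    if batch.ndim == 0:  # zero-dim arrays pass through unchanged
        return batch
    return list(batch)  # otherwise split along the leading batch axis

items = decollate_numpy(np.zeros((4, 1, 8, 8)))  # four arrays of shape (1, 8, 8)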
@@ -353,6 +353,10 @@ def decollate_batch(batch, detach: bool = True): if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) + if isinstance(batch, np.ndarray): + if batch.ndim == 0: + return batch + return list(batch) if isinstance(batch, Mapping): _dict_list = {key: decollate_batch(batch[key], detach) for key in batch} return [dict(zip(_dict_list, item)) for item in zip(*_dict_list.values())] diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 630c47f3a8..565ebbfc75 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -156,6 +156,6 @@ def compute_roc_auc( if average == Average.MACRO: return np.mean(auc_values) if average == Average.WEIGHTED: - weights = [sum(y_) for y_ in y] + weights = [sum(y_.cpu()) for y_ in y] return np.average(auc_values, weights=weights) raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].') diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index d887e6e949..0240087e89 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -185,7 +185,7 @@ def __call__( if threshold_values or self.threshold_values: img_t = img_t >= (logit_thresh or self.logit_thresh) - out, *_ = convert_data_type(img, orig_type, orig_device, dtype=float) + out, *_ = convert_data_type(img_t, orig_type, orig_device, dtype=float) return out diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 79d62b6436..c4ffc09b47 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -18,73 +18,85 @@ from monai.data import decollate_batch from monai.metrics import ROCAUCMetric, compute_roc_auc from monai.transforms import Activations, AsDiscrete, Compose, ToTensor +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - torch.tensor([[0], [1], [0], [1]]), - True, - True, - "macro", - 0.75, -] - -TEST_CASE_2 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([[0], [1], [0], [1]]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_3 = [ - torch.tensor([[0.5], [0.5], [0.2], [8.3]]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_4 = [ - torch.tensor([0.5, 0.5, 0.2, 8.3]), - torch.tensor([0, 1, 0, 1]), - False, - False, - "macro", - 0.875, -] - -TEST_CASE_5 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - torch.tensor([[0], [1], [0], [1]]), - True, - True, - "none", - [0.75, 0.75], -] - -TEST_CASE_6 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - True, - False, - "weighted", - 0.56667, -] - -TEST_CASE_7 = [ - torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - True, - False, - "micro", - 0.62, -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])), + q(torch.tensor([[0], [1], [0], [1]])), + True, + True, + "macro", + 0.75, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.5], [0.5], [0.2], [8.3]])), + q(torch.tensor([[0], [1], [0], [1]])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.5], [0.5], [0.2], [8.3]])), + q(torch.tensor([0, 1, 0, 1])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + 
p(torch.tensor([0.5, 0.5, 0.2, 8.3])), + q(torch.tensor([0, 1, 0, 1])), + False, + False, + "macro", + 0.875, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])), + q(torch.tensor([[0], [1], [0], [1]])), + True, + True, + "none", + [0.75, 0.75], + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]])), + q(torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]])), + True, + False, + "weighted", + 0.56667, + ] + ) + TESTS.append( + [ + p(torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]])), + q(torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]])), + True, + False, + "micro", + 0.62, + ] + ) class TestComputeROCAUC(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TESTS) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) @@ -93,7 +105,7 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TESTS) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)]) y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot, n_classes=2)]) From e1c10d110ff5a6b9f0afef5a949daa6e64278e77 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 16:55:19 +0100 Subject: [PATCH 125/176] vote ensemble Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/array.py | 6 +- tests/test_vote_ensemble.py | 108 +++++++++++++++++------------ tests/test_vote_ensembled.py | 121 +++++++++++++++++++-------------- 3 files changed, 139 insertions(+), 96 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 0240087e89..63a06a4a06 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -395,7 +395,7 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te return torch.mean(img_, dim=0) -class VoteEnsemble(Transform): +class VoteEnsemble(TorchTransform): """ Execute vote ensemble on the input data. 
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -419,7 +419,9 @@ def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + img_ = ( + torch.stack([torch.as_tensor(i) for i in img]) if isinstance(img, (tuple, list)) else torch.as_tensor(img) + ) if self.num_classes is not None: has_ch_dim = True if img_.ndimension() > 1 and img_.shape[1] > 1: diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 74c19d5f48..7abe120441 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -15,54 +15,74 @@ from parameterized import parameterized from monai.transforms import VoteEnsemble +from tests.utils import TEST_NDARRAYS -# shape: [2, 1, 1] -TEST_CASE_1 = [ - {"num_classes": None}, - [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], - torch.tensor([[[1.0]], [[0.0]]]), -] - -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"num_classes": None}, - torch.stack([torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]), - torch.tensor([[[[1.0]], [[0.0]]]]), -] - -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"num_classes": 3}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] - -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"num_classes": 5}, - [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], - torch.tensor([[[0], [2]]]), -] - -# shape: [1] -TEST_CASE_5 = [ - {"num_classes": 3}, - [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], - torch.tensor([2]), -] - -# shape: 1 -TEST_CASE_6 = [ - {"num_classes": 3}, - [torch.tensor(2), torch.tensor(2), torch.tensor(1)], - torch.tensor(2), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [2, 1, 1] + TESTS.append( + [ + p, + {"num_classes": None}, + [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], + torch.tensor([[[1.0]], [[0.0]]]), + ] + ) + TESTS.append( + # shape: [1, 2, 1, 1] + [ + p, + {"num_classes": None}, + torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ), + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) + TESTS.append( + # shape: [1, 2, 1] + [ + p, + {"num_classes": 3}, + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), + ] + ) + TESTS.append( + # shape: [1, 2, 1] + [ + p, + {"num_classes": 5}, + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), + ] + ) + TESTS.append( + # shape: [1] + [ + p, + {"num_classes": 3}, + [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], + torch.tensor([2]), + ] + ) + TESTS.append( + # shape: 1 + [ + p, + {"num_classes": 3}, + [torch.tensor(2), torch.tensor(2), torch.tensor(1)], + torch.tensor(2), + ] + ) class TestVoteEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, input_param, img, expected_value): - result = VoteEnsemble(**input_param)(img) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, img, expected_value): + result = VoteEnsemble(**input_param)([in_type(i) for i in img]) + if 
isinstance(result, torch.Tensor): + result = result.cpu() torch.testing.assert_allclose(result, expected_value) def test_cuda_value(self): diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index e94213733f..ab8daa1e5b 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -15,64 +15,85 @@ from parameterized import parameterized from monai.transforms import VoteEnsembled +from tests.utils import TEST_NDARRAYS -# shape: [1, 2, 1, 1] -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, - { - "pred0": torch.tensor([[[[1]], [[0]]]]), - "pred1": torch.tensor([[[[1]], [[0]]]]), - "pred2": torch.tensor([[[[0]], [[1]]]]), - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + # shape: [1, 2, 1, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": None}, + { + "pred0": torch.tensor([[[[1]], [[0]]]]), + "pred1": torch.tensor([[[[1]], [[0]]]]), + "pred2": torch.tensor([[[[0]], [[1]]]]), + }, + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) -# shape: [1, 2, 1, 1] -TEST_CASE_2 = [ - {"keys": "output", "output_key": "output", "num_classes": None}, - { - "output": torch.stack( - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] - ) - }, - torch.tensor([[[[1.0]], [[0.0]]]]), -] + # shape: [1, 2, 1, 1] + TESTS.append( + [ + p, + {"keys": "output", "output_key": "output", "num_classes": None}, + { + "output": torch.stack( + [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])] + ) + }, + torch.tensor([[[[1.0]], [[0.0]]]]), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + { + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), + }, + torch.tensor([[[0], [2]]]), + ] + ) -# shape: [1, 2, 1] -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, - { - "pred0": torch.tensor([[[0], [2]]]), - "pred1": torch.tensor([[[0], [2]]]), - "pred2": torch.tensor([[[1], [1]]]), - }, - torch.tensor([[[0], [2]]]), -] + # shape: [1, 2, 1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, + { + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), + }, + torch.tensor([[[0], [2]]]), + ] + ) -# shape: [1] -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, - torch.tensor([2]), -] + # shape: [1] + TESTS.append( + [ + p, + {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, + {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, + torch.tensor([2]), + ] + ) class TestVoteEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, input_param, img, expected_value): - result = VoteEnsembled(**input_param)(img) - 
torch.testing.assert_allclose(result["output"], expected_value)
+    @parameterized.expand(TESTS)
+    def test_value(self, in_type, input_param, img, expected_value):
+        for k, v in img.items():
+            img[k] = in_type(v)
+        result = VoteEnsembled(**input_param)(img)["output"]
+        result = result.cpu() if isinstance(result, torch.Tensor) else result
+        torch.testing.assert_allclose(result, expected_value)
 
     def test_cuda_value(self):
         img = torch.stack(

From a9a79b3e9594fb4e9cd0c07d5e6658c3cc17e674 Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Mon, 19 Jul 2021 17:12:11 +0100
Subject: [PATCH 126/176] threshold intensity

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/transforms/intensity/array.py | 12 ++++---
 tests/test_threshold_intensity.py   | 20 +++++++-----
 tests/test_threshold_intensityd.py  | 50 ++++++++++++++++++-----------
 3 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index b9aa45010f..7fe1cc0e8d 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -577,7 +577,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
         return img.astype(self.dtype)
 
 
-class ThresholdIntensity(Transform):
+class ThresholdIntensity(TorchTransform, NumpyTransform):
     """
     Filter the intensity values of the whole image, keeping only those above (or
     below) the `threshold`, and fill the remaining parts of the image with `cval`.
@@ -595,13 +595,15 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N
         self.above = above
         self.cval = cval
 
-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Apply the transform to `img`.
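        For example (a sketch; the tests below exercise exactly this case),
        thresholding at 5 with ``above=True`` keeps only the values above the
        threshold and fills the rest with ``cval``:

            >>> ThresholdIntensity(threshold=5, above=True, cval=0)(np.arange(10))
            array([0, 0, 0, 0, 0, 0, 6, 7, 8, 9])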
""" - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + where = np.where if isinstance(img, np.ndarray) else torch.where + cval = self.cval if isinstance(img, np.ndarray) else torch.tensor(self.cval, dtype=img.dtype, device=img.device) + res = where(img > self.threshold if self.above else img < self.threshold, img, cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res # type: ignore class ScaleIntensityRange(TorchTransform, NumpyTransform): diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..094efc1257 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -12,22 +12,26 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) + if isinstance(result, torch.Tensor): + result = result.cpu() np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..5d8b69a1de 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -12,31 +12,45 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), 
"label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) + for k, v in result.items(): + if isinstance(result[k], torch.Tensor): + result[k] = v.cpu() np.testing.assert_allclose(result["image"], expected_value) np.testing.assert_allclose(result["label"], expected_value) np.testing.assert_allclose(result["extra"], expected_value) From 566127196618b36368a2c8fd8a29a7cfc4a3c568 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 18:51:01 +0100 Subject: [PATCH 127/176] generate_spatial_bounding_box, BoundingRect, BoundingRectd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 6 +- monai/transforms/utils.py | 14 ++- tests/test_bounding_rect.py | 32 +++--- tests/test_bounding_rectd.py | 20 ++-- tests/test_generate_spatial_bounding_box.py | 119 ++++++++++++-------- 5 files changed, 109 insertions(+), 82 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index f028b5c4b4..09b5c62d65 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1078,7 +1078,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N return self.padder(self.cropper(img), mode=mode) # type: ignore -class BoundingRect(Transform): +class BoundingRect(TorchTransform, NumpyTransform): """ Compute coordinates of axis-aligned bounding rectangles from input image `img`. The output format of the coordinates is (shape is [channel, 2 * spatial dims]): @@ -1105,7 +1105,7 @@ class BoundingRect(Transform): def __init__(self, select_fn: Callable = is_positive) -> None: self.select_fn = select_fn - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`. """ @@ -1114,5 +1114,5 @@ def __call__(self, img: np.ndarray) -> np.ndarray: for channel in range(img.shape[0]): start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel) bbox.append([i for k in zip(start_, end_) for i in k]) - + bbox = [[j.cpu() if isinstance(j, torch.Tensor) else j for j in i] for i in bbox] # type: ignore return np.stack(bbox, axis=0) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 7a6562f18d..1d2de18083 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -670,7 +670,7 @@ def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) -> def generate_spatial_bounding_box( - img: np.ndarray, + img: DataObjects.Images, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, @@ -695,7 +695,7 @@ def generate_spatial_bounding_box( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. 
""" data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img - data = np.any(select_fn(data), axis=0) + data = select_fn(data).any(0) ndim = len(data.shape) margin = ensure_tuple_rep(margin, ndim) for m in margin: @@ -706,13 +706,15 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data.any(axis=ax) - if not np.any(dt): + dt = data if len(ax) == 0 else data.any(ax[0]) + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim + dt = dt if isinstance(dt, np.ndarray) else dt.int() + rev_dt = dt[::-1] if isinstance(dt, np.ndarray) else dt.flip(0) - min_d = max(np.argmax(dt) - margin[di], 0) - max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) + min_d = max(dt.argmax() - margin[di], 0) + max_d = max(data.shape[di] - max(rev_dt.argmax() - margin[di], 0), min_d + 1) box_start[di], box_end[di] = min_d, max_d return box_start, box_end diff --git a/tests/test_bounding_rect.py b/tests/test_bounding_rect.py index 38585cba18..bef97c47b5 100644 --- a/tests/test_bounding_rect.py +++ b/tests/test_bounding_rect.py @@ -12,38 +12,40 @@ import unittest import numpy as np +import torch from parameterized import parameterized import monai from monai.transforms import BoundingRect +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] +SEED = 1 -TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] - -TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {}, (2, 3), [[0, 0], [1, 2]]]) + TESTS.append([p, {}, (1, 8, 10), [[0, 7, 1, 9]]]) + TESTS.append([p, {}, (2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]) + TESTS.append([p, {"select_fn": lambda x: x < 1}, (2, 3), [[0, 3], [0, 3]]]) class TestBoundingRect(unittest.TestCase): def setUp(self): - monai.utils.set_determinism(1) + monai.utils.set_determinism(SEED) def tearDown(self): monai.utils.set_determinism(None) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_shape, expected): - test_data = np.random.randint(0, 8, size=input_shape) + @parameterized.expand(TESTS) + def test_result(self, im_type, input_args, input_shape, expected): + np.random.seed(SEED) + test_data = im_type(np.random.randint(0, 8, size=input_shape)) test_data = test_data == 7 - result = BoundingRect()(test_data) + result = BoundingRect(**input_args)(test_data) + if isinstance(result, torch.Tensor): + result = result.cpu() np.testing.assert_allclose(result, expected) - def test_select_fn(self): - test_data = np.random.randint(0, 8, size=(2, 3)) - test_data = test_data == 7 - bbox = BoundingRect(select_fn=lambda x: x < 1)(test_data) - np.testing.assert_allclose(bbox, [[0, 3], [0, 3]]) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_bounding_rectd.py b/tests/test_bounding_rectd.py index 6e725ff583..0bd07ee550 100644 --- a/tests/test_bounding_rectd.py +++ b/tests/test_bounding_rectd.py @@ -16,24 +16,28 @@ import monai from monai.transforms import BoundingRectD +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]] +SEED = 1 -TEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]] - -TEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, (2, 3), [[0, 0], [1, 2]]]) + TESTS.append([p, (1, 8, 10), [[0, 7, 1, 9]]]) + 
TESTS.append([p, (2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]) class TestBoundingRectD(unittest.TestCase): def setUp(self): - monai.utils.set_determinism(1) + monai.utils.set_determinism(SEED) def tearDown(self): monai.utils.set_determinism(None) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_shape, expected): - test_data = np.random.randint(0, 8, size=input_shape) + @parameterized.expand(TESTS) + def test_value(self, im_type, input_shape, expected): + np.random.seed(SEED) + test_data = im_type(np.random.randint(0, 8, size=input_shape)) test_data = test_data == 7 result = BoundingRectD("image")({"image": test_data}) np.testing.assert_allclose(result["image_bbox"], expected) diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index 32a45d8d1c..d73b9fafcc 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -15,60 +15,79 @@ from parameterized import parameterized from monai.transforms import generate_spatial_bounding_box +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_2 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 1, - "channel_indices": None, - "margin": 0, - }, - ([2, 2], [3, 3]), -] - -TEST_CASE_3 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_4 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 1, - }, - ([0, 0], [4, 5]), -] - -TEST_CASE_5 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": [2, 1], - }, - ([0, 0], [5, 5]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 1, + "channel_indices": None, + "margin": 0, + }, + ([2, 2], [3, 3]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 1, + }, + ([0, 0], [4, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 
1], + }, + ([0, 0], [5, 5]), + ] + ) class TestGenerateSpatialBoundingBox(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) self.assertTupleEqual(result, expected_box) From 317eaf3ce2b0291df0222460ce19b70a5df2c2a6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 19:11:33 +0100 Subject: [PATCH 128/176] transpose, transposed Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 12 ++--- tests/test_transpose.py | 18 ++++--- tests/test_transposed.py | 80 +++++++++++++++++++++---------- 3 files changed, 71 insertions(+), 39 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 765caf8395..729996c0bc 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -368,7 +368,7 @@ def __call__(self, data): return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) -class ToNumpy(NumpyTransform): +class ToNumpy(TorchTransform, NumpyTransform): """ Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor. """ @@ -414,7 +414,7 @@ def __call__(self, img): return pil_image_fromarray(ToNumpy()(img)) -class Transpose(NumpyTransform): +class Transpose(TorchTransform, NumpyTransform): """ Transposes the input image based on the given `indices` dimension ordering. """ @@ -426,11 +426,9 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. """ - img_np: np.ndarray - img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore - img_np = img_np.transpose(self.indices) # type: ignore - out, *_ = convert_data_type(img_np, orig_type, orig_device) - return out + if isinstance(img, torch.Tensor): + return img.permute(self.indices or tuple(range(img.ndim)[::-1])) + return img.transpose(self.indices) # type: ignore class SqueezeDim(TorchTransform, NumpyTransform): diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 3c30fba281..8602027d87 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -16,20 +16,26 @@ from parameterized import parameterized from monai.transforms import Transpose +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] -for q in (np.arange, torch.arange): - TEST_CASES.append([q(5 * 4).reshape(5, 4), None]) # type: ignore - TEST_CASES.append([q(5 * 4 * 3).reshape(5, 4, 3), [2, 0, 1]]) # type: ignore +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None]) + TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]]) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): tr = Transpose(indices) out1 = tr(im) - out2 = np.transpose(im, indices) + im_cpu = im if isinstance(im, np.ndarray) else im.cpu() + out2 = np.transpose(im_cpu, indices) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(im.device, out1.device) + out1 = out1.cpu() np.testing.assert_array_equal(out1, out2) diff --git a/tests/test_transposed.py b/tests/test_transposed.py index 56375f3981..9998a8c76c 100644 --- a/tests/test_transposed.py +++ b/tests/test_transposed.py @@ -13,44 +13,72 @@ from copy 
import deepcopy import numpy as np +import torch from parameterized import parameterized from monai.transforms import Transposed +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - np.arange(5 * 4).reshape(5, 4), - [1, 0], -] -TEST_CASE_1 = [ - np.arange(5 * 4).reshape(5, 4), - None, -] -TEST_CASE_2 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - [2, 0, 1], -] -TEST_CASE_3 = [ - np.arange(5 * 4 * 3).reshape(5, 4, 3), - None, -] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +KEYS = ("i", "j") + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + [1, 0], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4).reshape(5, 4)), + None, + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + [2, 0, 1], + ] + ) + TESTS.append( + [ + p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), + None, + ] + ) class TestTranspose(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_transpose(self, im, indices): - data = {"i": deepcopy(im), "j": deepcopy(im)} - tr = Transposed(["i", "j"], indices) + im_cpu = deepcopy(im if isinstance(im, np.ndarray) else im.cpu()) + out_gt = np.transpose(im_cpu, indices) + + data = {k: deepcopy(im) for k in KEYS} + tr = Transposed(KEYS, indices) out_data = tr(data) - out_im1, out_im2 = out_data["i"], out_data["j"] - out_gt = np.transpose(im, indices) - np.testing.assert_array_equal(out_im1, out_gt) - np.testing.assert_array_equal(out_im2, out_gt) + + for k, v in out_data.items(): + if k not in KEYS: + continue + self.assertEqual(type(im), type(v)) + if isinstance(v, torch.Tensor): + self.assertEqual(im.device, v.device) + v = v.cpu() + np.testing.assert_array_equal(v, out_gt) # test inverse fwd_inv_data = tr.inverse(out_data) - for i, j in zip(data.values(), fwd_inv_data.values()): - np.testing.assert_array_equal(i, j) + + for k, v in fwd_inv_data.items(): + if k not in KEYS: + continue + self.assertEqual(type(im), type(v)) + if isinstance(v, torch.Tensor): + self.assertEqual(im.device, v.device) + v = v.cpu() + np.testing.assert_array_equal(v, im_cpu) if __name__ == "__main__": From b9269a6f0c423795c048b6678cf9269c6461005e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 19 Jul 2021 19:12:27 +0100 Subject: [PATCH 129/176] get_extreme Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1d2de18083..af169ac09f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -791,8 +791,8 @@ def _get_point(val, dim): points = [] for i in range(img.ndim): - points.append(tuple(_get_point(indices[i][...].min(), i))) - points.append(tuple(_get_point(indices[i][...].max(), i))) + points.append(tuple(_get_point(indices[i].min(), i))) + points.append(tuple(_get_point(indices[i].max(), i))) return points From 1df2dcbe723d5eeb455ab4a4e55754a08244c0c2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 10:09:34 +0100 Subject: [PATCH 130/176] update gibbs Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_gibbs_noise.py | 7 +++--- tests/test_gibbs_noised.py | 34 +++++++++++++++------------- tests/test_rand_gibbs_noise.py | 4 +--- tests/test_rand_gibbs_noised.py | 39 +++++++++++++++++++-------------- 4 files changed, 46 insertions(+), 38 
deletions(-) diff --git a/tests/test_gibbs_noise.py b/tests/test_gibbs_noise.py index a93c72484a..2c5e117eaf 100644 --- a/tests/test_gibbs_noise.py +++ b/tests/test_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft.fftshift") +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -42,8 +42,6 @@ def tearDown(self): def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im - im = create_test_image(*im_shape, 4, 20, 0, 5)[0][None] return input_type(im) @parameterized.expand(TEST_CASES) @@ -53,6 +51,9 @@ def test_same_result(self, im_shape, input_type): t = GibbsNoise(alpha) out1 = t(deepcopy(im)) out2 = t(deepcopy(im)) + self.assertEqual(type(out1), type(im)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) torch.testing.assert_allclose(out1, out2, rtol=1e-7, atol=0) self.assertIsInstance(out1, type(im)) diff --git a/tests/test_gibbs_noised.py b/tests/test_gibbs_noised.py index 3044d8f682..f02052818f 100644 --- a/tests/test_gibbs_noised.py +++ b/tests/test_gibbs_noised.py @@ -22,14 +22,12 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft.fftshift") +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: - TEST_CASES.append((shape, as_tensor_output, input_type)) - + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] @@ -45,15 +43,13 @@ def tearDown(self): def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} - return {k: input_type(create_test_image(*im_shape, 4, 20, 0, 5)[0]) for k in KEYS} + return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, input_type): + def test_same_result(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = 0.8 - t = GibbsNoised(KEYS, alpha, as_tensor_output) + t = GibbsNoised(KEYS, alpha) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: @@ -61,25 +57,33 @@ def test_same_result(self, im_shape, as_tensor_output, input_type): self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, input_type): + def test_identity(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = 0.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - torch.testing.assert_allclose(data[k], out[k], atol=1e-2, rtol=1e-7) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + 
np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, input_type): + def test_alpha_1(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = 1.0 t = GibbsNoised(KEYS, alpha) out = t(deepcopy(data)) for k in KEYS: - torch.testing.assert_allclose(0 * data[k], out[k], rtol=1e-7, atol=0) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, input_type): + def test_dict_matches(self, im_shape, input_type): data = self.get_data(im_shape, input_type) data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} alpha = 1.0 diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 4a2496f8da..15cadea0e2 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -22,7 +22,7 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft.fftshift") +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): @@ -42,8 +42,6 @@ def tearDown(self): def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im - im = create_test_image(*im_shape, 4, 20, 0, 5)[0][None] return input_type(im) @parameterized.expand(TEST_CASES) diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index cb6e8a9cd0..b8bac67b81 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -22,13 +22,12 @@ from monai.utils.module import optional_import from tests.utils import TEST_NDARRAYS -_, has_torch_fft = optional_import("torch.fft.fftshift") +_, has_torch_fft = optional_import("torch.fft", name="fftshift") TEST_CASES = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: - TEST_CASES.append((shape, as_tensor_output, input_type)) + for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]: + TEST_CASES.append((shape, input_type)) KEYS = ["im", "label"] @@ -45,24 +44,22 @@ def tearDown(self): def get_data(im_shape, input_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims - return {k: v for k, v in zip(KEYS, ims)} - return {k: input_type(create_test_image(*im_shape, 4, 20, 0, 5)[0]) for k in KEYS} + return {k: input_type(v) for k, v in zip(KEYS, ims)} @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, input_type): + def test_0_prob(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = [0.5, 1.0] - t = RandGibbsNoised(KEYS, 0.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 0.0, alpha) out = t(data) for k in KEYS: torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, 
as_tensor_output, input_type): + def test_same_result(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = [0.5, 0.8] - t = RandGibbsNoised(KEYS, 1.0, alpha, as_tensor_output) + t = RandGibbsNoised(KEYS, 1.0, alpha) t.set_random_state(42) out1 = t(deepcopy(data)) t.set_random_state(42) @@ -72,25 +69,33 @@ def test_same_result(self, im_shape, as_tensor_output, input_type): self.assertIsInstance(out1[k], type(data[k])) @parameterized.expand(TEST_CASES) - def test_identity(self, im_shape, _, input_type): + def test_identity(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = [0.0, 0.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - torch.testing.assert_allclose(data[k], out[k], atol=1e-2, rtol=1e-7) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_alpha_1(self, im_shape, _, input_type): + def test_alpha_1(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = [1.0, 1.0] t = RandGibbsNoised(KEYS, 1.0, alpha) out = t(deepcopy(data)) for k in KEYS: - torch.testing.assert_allclose(0 * data[k], out[k], rtol=1e-7, atol=0) + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k], data[k] = out[k].cpu(), data[k].cpu() + np.testing.assert_allclose(0.0 * data[k], out[k], atol=1e-2) @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, input_type): + def test_dict_matches(self, im_shape, input_type): data = self.get_data(im_shape, input_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} @@ -100,7 +105,7 @@ def test_dict_matches(self, im_shape, _, input_type): torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0) @parameterized.expand(TEST_CASES) - def test_alpha(self, im_shape, _, input_type): + def test_alpha(self, im_shape, input_type): data = self.get_data(im_shape, input_type) alpha = [0.5, 0.51] t = RandGibbsNoised(KEYS, 1.0, alpha) From e9b45138f34d4e88a074223b2a1119f99be61105 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 10:45:35 +0100 Subject: [PATCH 131/176] as discrete Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_as_discrete.py | 15 ++++-- tests/test_as_discreted.py | 101 +++++++++++++++++++------------------ 2 files changed, 63 insertions(+), 53 deletions(-) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index a9f9548177..2ed24f7f32 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -12,6 +12,7 @@ import unittest import torch +import numpy as np from parameterized import parameterized from monai.transforms import AsDiscrete @@ -23,7 +24,7 @@ [ {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), - p([[[1.0, 1.0]]]), + ([[[1.0, 1.0]]]), (1, 1, 2), ] ) @@ -32,7 +33,7 @@ [ {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, p([[[0.0, 1.0]], [[2.0, 3.0]]]), - p([[[0.0, 0.0]], [[1.0, 1.0]]]), + ([[[0.0, 0.0]], [[1.0, 1.0]]]), 
(2, 1, 2), ] ) @@ -41,7 +42,7 @@ [ {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, p([[[0.0, 1.0], [2.0, 3.0]]]), - p([[[0.0, 1.0], [1.0, 1.0]]]), + ([[[0.0, 1.0], [1.0, 1.0]]]), (1, 2, 2), ] ) @@ -50,7 +51,7 @@ [ {"argmax": False, "to_onehot": True, "n_classes": 3}, p(1), - p([0.0, 1.0, 0.0]), + ([0.0, 1.0, 0.0]), (3,), ] ) @@ -60,7 +61,11 @@ class TestAsDiscrete(unittest.TestCase): @parameterized.expand(TESTS) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) - torch.testing.assert_allclose(result, out) + self.assertEqual(type(img), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(img.device, result.device) + result = result.cpu() + np.testing.assert_allclose(result, out) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index d6a6f3c2a4..f1a98bf837 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -9,65 +9,70 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tests.utils import TEST_NDARRAYS import unittest - +import numpy as np import torch from parameterized import parameterized from monai.transforms import AsDiscreted -TEST_CASE_1 = [ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0, 1]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, - (2, 1, 2), -] - -TEST_CASE_2 = [ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "n_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0, 1], [1, 1]]])}, - {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, - (1, 2, 2), -] - -TEST_CASE_3 = [ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, - (2, 1, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([ + { + "keys": ["pred", "label"], + "argmax": [True, False], + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), + ]) + TESTS.append([ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": False, + "n_classes": None, + "threshold_values": [True, False], + "logit_thresh": 0.6, + }, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1], [1, 1]]]))}, + {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ]) + TESTS.append([ + { + "keys": ["pred"], + "argmax": True, + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), + ]) class 
TestAsDiscreted(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, test_input, output, expected_shape): result = AsDiscreted(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) - self.assertTupleEqual(result["pred"].shape, expected_shape) - if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) - self.assertTupleEqual(result["label"].shape, expected_shape) + for k in ("label", "pred"): + if k not in result: + continue + i, r, o = test_input[k], result[k], output[k] + self.assertEqual(type(i), type(r)) + if isinstance(r, torch.Tensor): + self.assertEqual(i.device, r.device) + r = r.cpu() + np.testing.assert_allclose(r, o) if __name__ == "__main__": From 045d41830e743c8e2a0612d1adc5bef7feabde43 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 11:14:23 +0100 Subject: [PATCH 132/176] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_as_discrete.py | 2 +- tests/test_as_discreted.py | 87 ++++++++++++++++++++------------------ 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 2ed24f7f32..3bb9702f86 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -11,8 +11,8 @@ import unittest -import torch import numpy as np +import torch from parameterized import parameterized from monai.transforms import AsDiscrete diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index f1a98bf837..d00e6dad24 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -9,55 +9,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
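# The AsDiscrete/AsDiscreted cases above combine three post-processing modes:
# argmax, one-hot encoding, and thresholding. A compact sketch of each on a
# (2, 1, 2) logits tensor (the 0.5 threshold is illustrative only):
import torch

logits = torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])
argmaxed = logits.argmax(dim=0, keepdim=True)  # tensor([[[1, 1]]])
onehot = torch.nn.functional.one_hot(
    argmaxed.squeeze(0), num_classes=2
).permute(2, 0, 1).float()  # tensor([[[0., 0.]], [[1., 1.]]])
thresholded = (logits >= 0.5).float()  # tensor([[[0., 1.]], [[1., 1.]]])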
-from tests.utils import TEST_NDARRAYS import unittest + import numpy as np import torch from parameterized import parameterized from monai.transforms import AsDiscreted +from tests.utils import TEST_NDARRAYS TESTS = [] for p in TEST_NDARRAYS: - TESTS.append([ - { - "keys": ["pred", "label"], - "argmax": [True, False], - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1]]]))}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, - (2, 1, 2), - ]) - TESTS.append([ - { - "keys": ["pred", "label"], - "argmax": False, - "to_onehot": False, - "n_classes": None, - "threshold_values": [True, False], - "logit_thresh": 0.6, - }, - {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1], [1, 1]]]))}, - {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, - (1, 2, 2), - ]) - TESTS.append([ - { - "keys": ["pred"], - "argmax": True, - "to_onehot": True, - "n_classes": 2, - "threshold_values": False, - "logit_thresh": 0.5, - }, - {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]))}, - {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, - (2, 1, 2), - ]) + TESTS.append( + [ + { + "keys": ["pred", "label"], + "argmax": [True, False], + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), + ] + ) + TESTS.append( + [ + { + "keys": ["pred", "label"], + "argmax": False, + "to_onehot": False, + "n_classes": None, + "threshold_values": [True, False], + "logit_thresh": 0.6, + }, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0, 1], [1, 1]]]))}, + {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), + ] + ) + TESTS.append( + [ + { + "keys": ["pred"], + "argmax": True, + "to_onehot": True, + "n_classes": 2, + "threshold_values": False, + "logit_thresh": 0.5, + }, + {"pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]))}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), + ] + ) class TestAsDiscreted(unittest.TestCase): From b8bcb94f654734110c0aa17b6298aebfeba673af Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 12:23:15 +0100 Subject: [PATCH 133/176] activations Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/array.py | 2 +- tests/test_activations.py | 137 +++++++++++++++++++++------------ tests/test_activationsd.py | 80 +++++++++++-------- 3 files changed, 136 insertions(+), 83 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 63a06a4a06..03041d2d94 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -92,7 +92,7 @@ def __call__( raise TypeError(f"other must be None or callable but is {type(other).__name__}.") img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore # convert to float as activation must operate on 
float tensor if sigmoid or self.sigmoid: diff --git a/tests/test_activations.py b/tests/test_activations.py index 7d8b3e4c38..66867daf3a 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -11,63 +11,100 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.networks.layers.factories import Act from monai.transforms import Activations - -TEST_CASE_1 = [ - {"sigmoid": True, "softmax": False, "other": None}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]]), - (1, 2, 2), -] - -TEST_CASE_2 = [ - {"sigmoid": False, "softmax": True, "other": None}, - torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - (2, 1, 2), -] - -TEST_CASE_3 = [ - {"sigmoid": False, "softmax": False, "other": torch.tanh}, - torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - (1, 2, 2), -] - -TEST_CASE_4 = [ - "swish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] - ), - (1, 2, 5), -] - -TEST_CASE_5 = [ - "memswish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] - ), - (1, 2, 5), -] - -TEST_CASE_6 = [ - "mish", - torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), - torch.tensor( - [[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]] - ), - (1, 2, 5), -] +from tests.utils import TEST_NDARRAYS + +TESTS1, TESTS2 = [], [] +for p in TEST_NDARRAYS: + TESTS1.append( + [ + {"sigmoid": True, "softmax": False, "other": None}, + p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), + p(torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]])), + (1, 2, 2), + ] + ) + TESTS1.append( + [ + {"sigmoid": False, "softmax": True, "other": None}, + p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + p(torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]])), + (2, 1, 2), + ] + ) + TESTS1.append( + [ + {"sigmoid": False, "softmax": False, "other": torch.tanh}, + p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), + p(torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])), + (1, 2, 2), + ] + ) + +for p in TEST_NDARRAYS: + # these tests require torch.Tensor + if p == np.array: + continue + TESTS2.append( + [ + "swish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], + [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) + TESTS2.append( + [ + "memswish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], + [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) + TESTS2.append( + [ + "mish", + p(torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32)), + p( + torch.tensor( + [ + [ + [-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], + [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00], + ] + ] + ) + ), + (1, 2, 5), + ] + ) class TestActivations(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + 
@parameterized.expand(TESTS1) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) @@ -81,7 +118,7 @@ def _compare(ret, out, shape): else: _compare(result, out, expected_shape) - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS2) def test_monai_activations_value_shape(self, input_param, img, out, expected_shape): act = Act[input_param]() result = act(img) diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index 355c50f389..4ed550fabf 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -11,48 +11,64 @@ import unittest +import numpy as np import torch from parameterized import parameterized from monai.transforms import Activationsd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, - {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), - "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), - }, - (2, 1, 2), -] - -TEST_CASE_2 = [ - {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - { - "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), - "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), - }, - (1, 2, 2), -] - -TEST_CASE_3 = [ - {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, - {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, - {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, - (1, 2, 2), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, + { + "pred": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + "label": p(torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])), + }, + { + "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), + }, + (2, 1, 2), + ] + ) + TESTS.append( + [ + {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])), "label": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]))}, + { + "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), + }, + (1, 2, 2), + ] + ) + TESTS.append( + [ + {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, + {"pred": p(torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]))}, + {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, + (1, 2, 2), + ] + ) class TestActivationsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value_shape(self, input_param, test_input, output, expected_shape): result = Activationsd(**input_param)(test_input) - torch.testing.assert_allclose(result["pred"], output["pred"]) - self.assertTupleEqual(result["pred"].shape, expected_shape) - if "label" in result: - torch.testing.assert_allclose(result["label"], output["label"]) - self.assertTupleEqual(result["label"].shape, expected_shape) + for k in ("label", "pred"): + if k not in result: + continue + i, r, o = test_input[k], result[k], output[k] + self.assertEqual(type(i), type(r)) + if isinstance(r, torch.Tensor): + 
self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, o, rtol=1e-4, atol=1e-5) + self.assertTupleEqual(r.shape, expected_shape) if __name__ == "__main__": From 0aeafc91f408316ff8a7cc717f26b3eae2b8e1a6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 12:38:40 +0100 Subject: [PATCH 134/176] float -> torch.float32 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- monai/transforms/post/array.py | 2 +- monai/transforms/spatial/array.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 7fe1cc0e8d..ef5d92af7d 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1055,7 +1055,7 @@ def __init__( def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore gf1, gf2 = [ GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 03041d2d94..14e2c1d318 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -506,7 +506,7 @@ def __call__( prob_map: the input probabilities map, it must have shape (H[, W, ...]). """ prob_map_t: torch.Tensor - prob_map_t, *_ = convert_data_type(deepcopy(prob_map), torch.Tensor, dtype=float) # type: ignore + prob_map_t, *_ = convert_data_type(deepcopy(prob_map), torch.Tensor, dtype=torch.float32) # type: ignore if self.sigma != 0: self.filter.to(prob_map_t) prob_map_t = self.filter(prob_map_t) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index f588a64fd2..b31e59e582 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1234,7 +1234,7 @@ def __call__( raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, device=self.device, dtype=float) # type: ignore + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, device=self.device, dtype=torch.float32) # type: ignore grid, *_ = convert_data_type(deepcopy(grid), torch.Tensor, device=img_t.device, dtype=float) if USE_COMPILED: From 3e58f254859a033274bfe11c929b8ae315879a58 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 12:48:03 +0100 Subject: [PATCH 135/176] line Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b31e59e582..2c5f073b53 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1234,7 +1234,9 @@ def __call__( raise AssertionError("Error, grid argument must be supplied as an ndarray or tensor ") img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, device=self.device, dtype=torch.float32) # type: ignore + img_t, orig_type, orig_device = convert_data_type( # type: ignore + img, torch.Tensor, device=self.device, 
dtype=torch.float32 + ) grid, *_ = convert_data_type(deepcopy(grid), torch.Tensor, device=img_t.device, dtype=float) if USE_COMPILED: From bc67dbe9b1ea230699dc896a126137532b2d128a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 13:20:07 +0100 Subject: [PATCH 136/176] gaussian noise(d) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 9 +++-- monai/transforms/intensity/dictionary.py | 11 ++++-- tests/test_rand_gaussian_noise.py | 31 +++++++-------- tests/test_rand_gaussian_noised.py | 48 +++++++++++------------- 4 files changed, 49 insertions(+), 50 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ef5d92af7d..a615052dff 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -61,7 +61,7 @@ ] -class RandGaussianNoise(RandomizableTransform): +class RandGaussianNoise(TorchTransform, NumpyTransform, RandomizableTransform): """ Add Gaussian noise to image. @@ -86,12 +86,15 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: Apply the transform to `img`. """ self.randomize(img.shape) + if self._noise is None: raise AssertionError if not self._do_transform: return img - dtype = dtype_convert(img.dtype, np.array) - return img + self._noise.astype(dtype) + noise, *_ = convert_data_type( + self._noise, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) + return img + noise class RandRicianNoise(TorchTransform, NumpyTransform, RandomizableTransform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 22525b41b5..d38fa8d857 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -43,7 +43,7 @@ from monai.transforms.transform import MapTransform, RandomizableTransform from monai.utils import ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects -from monai.utils.misc import dtype_convert +from monai.utils.misc import convert_data_type __all__ = [ "RandGaussianNoised", @@ -166,8 +166,13 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: if not self._do_transform: return d for key, noise in self.key_iterator(d, self._noise): - dtype = dtype_convert(d[key].dtype, np.array) - d[key] = d[key] + noise.astype(dtype) + noise, *_ = convert_data_type( + noise, + type(d[key]), + dtype=d[key].dtype, + device=d[key].device if isinstance(d[key], torch.Tensor) else None, + ) + d[key] = d[key] + noise return d diff --git a/tests/test_rand_gaussian_noise.py b/tests/test_rand_gaussian_noise.py index 96f1fa8e6d..d376add460 100644 --- a/tests/test_rand_gaussian_noise.py +++ b/tests/test_rand_gaussian_noise.py @@ -12,35 +12,32 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianNoise -from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(("test_zero_mean", p, 0, 0.1)) + TESTS.append(("test_non_zero_mean", p, 1, 0.5)) -class TestRandGaussianNoise(NumpyImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) - def test_correct_results(self, _, mean, std): - seed = 0 - gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) - 
gaussian_fn.set_random_state(seed) - noised = gaussian_fn(self.imt) - np.random.seed(seed) - np.random.random() - expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) - np.testing.assert_allclose(expected, noised, atol=1e-5) - -class TestRandGaussianNoiseTorch(TorchImageTestCase2D): - @parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)]) - def test_correct_results(self, _, mean, std): +class TestRandGaussianNoise(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, im_type, mean, std): seed = 0 gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std) gaussian_fn.set_random_state(seed) - noised = gaussian_fn(self.imt) + im = im_type(self.imt) + noised = gaussian_fn(im) np.random.seed(seed) np.random.random() expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + self.assertEqual(type(im), type(noised)) + if isinstance(noised, torch.Tensor): + noised = noised.cpu() np.testing.assert_allclose(expected, noised, atol=1e-5) diff --git a/tests/test_rand_gaussian_noised.py b/tests/test_rand_gaussian_noised.py index 442a85ca77..4b0d2a311a 100644 --- a/tests/test_rand_gaussian_noised.py +++ b/tests/test_rand_gaussian_noised.py @@ -12,41 +12,35 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandGaussianNoised -from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1] -TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5] -TEST_CASES = [TEST_CASE_0, TEST_CASE_1] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1]) + TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5]) seed = 0 -def test_numpy_or_torch(keys, mean, std, imt): - gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) - gaussian_fn.set_random_state(seed) - noised = gaussian_fn({k: imt for k in keys}) - np.random.seed(seed) - np.random.random() - for k in keys: - expected = imt + np.random.normal(mean, np.random.uniform(0, std), size=imt.shape) - np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) - - -# Test with numpy -class TestRandGaussianNoisedNumpy(NumpyImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) - - -# Test with torch -class TestRandGaussianNoisedTorch(TorchImageTestCase2D): - @parameterized.expand(TEST_CASES) - def test_correct_results(self, _, keys, mean, std): - test_numpy_or_torch(keys, mean, std, self.imt) +class TestRandGaussianNoised(NumpyImageTestCase2D): + @parameterized.expand(TESTS) + def test_correct_results(self, _, im_type, keys, mean, std): + gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std) + gaussian_fn.set_random_state(seed) + im = im_type(self.imt) + noised = gaussian_fn({k: im for k in keys}) + np.random.seed(seed) + np.random.random() + for k in keys: + expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape) + self.assertEqual(type(im), type(noised[k])) + if isinstance(noised[k], torch.Tensor): + noised[k] = noised[k].cpu() + np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5) if __name__ == "__main__": From c4fa7e81f2d7ed17d5c61549a569cbc0a2e0423f Mon Sep 17 00:00:00 
2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 13:20:21 +0100 Subject: [PATCH 137/176] update rb script Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/rb_test_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/rb_test_transforms.py b/tests/rb_test_transforms.py index 798f000345..83eeb73756 100644 --- a/tests/rb_test_transforms.py +++ b/tests/rb_test_transforms.py @@ -35,6 +35,8 @@ def print_colour(t, colour): "LoadImage", "Compose", "RandomizableTransform", + "NumpyTransform", + "TorchTransform", "ToPIL", "ToCupy", ]: From 38c5d48532d4afc806a48fe866babb64bc948560 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 13:36:22 +0100 Subject: [PATCH 138/176] labeltocontour Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/post/array.py | 26 +++++++----- monai/transforms/post/dictionary.py | 2 +- tests/test_label_to_contour.py | 60 +++++++++++++++------------- tests/test_label_to_contourd.py | 62 ++++++++++++++++------------- 4 files changed, 83 insertions(+), 67 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 14e2c1d318..0ca4d8122f 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -303,7 +303,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return out -class LabelToContour(Transform): +class LabelToContour(TorchTransform): """ Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel set as default for edge detection. Typical usage is to plot the edge of label or segmentation output. @@ -321,7 +321,7 @@ def __init__(self, kernel_type: str = "Laplace") -> None: raise NotImplementedError('Currently only kernel_type="Laplace" is supported.') self.kernel_type = kernel_type - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]] @@ -337,22 +337,28 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: ideally the edge should be thin enough, but now it has a thickness. 
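        For intuition, a minimal sketch of the 2D behaviour (an illustrative example,
        assuming a single-channel binary input; not taken from the patch's test data):

            img = torch.zeros(1, 5, 5)
            img[0, 1:4, 1:4] = 1.0  # a 3x3 foreground square
            contour = LabelToContour()(img)
            # contour is 1.0 on the border pixels of the square and 0.0 elsewhere:
            # the Laplace kernel sums to zero, so its response vanishes on the
            # square's constant interior, and only the edge survives the clamp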
""" - channels = img.shape[0] - img_ = img.unsqueeze(0) - if img.ndimension() == 3: - kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device) + img_t: torch.Tensor + img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore + + channels = img_t.shape[0] + img_ = img_t.unsqueeze(0) + if img_t.ndimension() == 3: + kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img_t.device) kernel = kernel.repeat(channels, 1, 1, 1) contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) - elif img.ndimension() == 4: - kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device) + elif img_t.ndimension() == 4: + kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img_t.device) kernel[1, 1, 1] = 26 kernel = kernel.repeat(channels, 1, 1, 1, 1) contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) else: - raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].") + raise ValueError(f"Unsupported img dimension: {img_t.ndimension()}, available options are [4, 5].") contour_img.clamp_(min=0.0, max=1.0) - return contour_img.squeeze(0) + contour_img = contour_img.squeeze(0) + + out, *_ = convert_data_type(contour_img, orig_type, orig_device) + return out class MeanEnsemble(Transform): diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7da2b4abee..a3ad56730e 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -225,7 +225,7 @@ def __init__(self, keys: KeysCollection, kernel_type: str = "Laplace", allow_mis super().__init__(keys, allow_missing_keys) self.converter = LabelToContour(kernel_type=kernel_type) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index 8f8f3cc054..c34ff2e7e0 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -15,6 +15,7 @@ import torch from monai.transforms import LabelToContour +from tests.utils import TEST_NDARRAYS expected_output_for_cube = np.array( [ @@ -144,34 +145,37 @@ def gen_fixed_img(): class TestContour(unittest.TestCase): def test_contour(self): input_param = {"kernel_type": "Laplace"} - - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContour(**input_param)(cube) - self.assertEqual(test_result_cube.shape, cube.shape) - - test_result_np = test_result_cube.cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContour(**input_param)(img) - self.assertEqual(test_result_img.shape, img.shape) - - test_result_np = test_result_img.cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check invalid input data - error_input = torch.rand(1, 2) - self.assertRaises(ValueError, LabelToContour(**input_param), error_input) - error_input = torch.rand(1, 2, 3, 
4, 5) - self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + for p in TEST_NDARRAYS: + + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube() + for cube in p(test_cube): + test_result_cube = LabelToContour(**input_param)(cube) + self.assertEqual(test_result_cube.shape, cube.shape) + self.assertEqual(type(cube), type(test_result_cube)) + if isinstance(test_result_cube, torch.Tensor): + test_result_cube = test_result_cube.cpu() + channels = cube.shape[0] + for channel in range(channels): + np.testing.assert_allclose(test_result_cube[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img() + for img in p(test_img): + channels = img.shape[0] + test_result_img = LabelToContour(**input_param)(img) + self.assertEqual(test_result_img.shape, img.shape) + self.assertEqual(type(test_result_img), type(img)) + if isinstance(test_result_img, torch.Tensor): + test_result_img = test_result_img.cpu() + for channel in range(channels): + np.testing.assert_allclose(test_result_img[channel, ...], expected_output) + + # check invalid input data + error_input = p(torch.rand(1, 2)) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) + error_input = p(torch.rand(1, 2, 3, 4, 5)) + self.assertRaises(ValueError, LabelToContour(**input_param), error_input) if __name__ == "__main__": diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index d3795755c7..5b5b8260d1 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -15,6 +15,7 @@ import torch from monai.transforms import LabelToContourd +from tests.utils import TEST_NDARRAYS expected_output_for_cube = np.array( [ @@ -144,34 +145,39 @@ def gen_fixed_img(): class TestContourd(unittest.TestCase): def test_contour(self): input_param = {"keys": "img", "kernel_type": "Laplace"} - - # check 5-dim input data - test_cube, expected_output = gen_fixed_cube() - for cube in test_cube: - test_result_cube = LabelToContourd(**input_param)({"img": cube}) - self.assertEqual(test_result_cube["img"].shape, cube.shape) - - test_result_np = test_result_cube["img"].cpu().numpy() - channels = cube.shape[0] - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check 4-dim input data - test_img, expected_output = gen_fixed_img() - for img in test_img: - channels = img.shape[0] - test_result_img = LabelToContourd(**input_param)({"img": img}) - self.assertEqual(test_result_img["img"].shape, img.shape) - - test_result_np = test_result_img["img"].cpu().numpy() - for channel in range(channels): - np.testing.assert_allclose(test_result_np[channel, ...], expected_output) - - # check invalid input data - error_input = {"img": torch.rand(1, 2)} - self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) - error_input = {"img": torch.rand(1, 2, 3, 4, 5)} - self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) + for p in TEST_NDARRAYS: + # check 5-dim input data + test_cube, expected_output = gen_fixed_cube() + for cube in p(test_cube): + test_result_cube = LabelToContourd(**input_param)({"img": cube})["img"] + self.assertEqual(test_result_cube.shape, cube.shape) + self.assertEqual(type(test_result_cube), type(cube)) + if isinstance(test_result_cube, torch.Tensor): + self.assertEqual(cube.device, test_result_cube.device) + test_result_cube = test_result_cube.cpu() + + channels = cube.shape[0] + for channel in range(channels): + 
np.testing.assert_allclose(test_result_cube[channel, ...], expected_output) + + # check 4-dim input data + test_img, expected_output = gen_fixed_img() + for img in p(test_img): + channels = img.shape[0] + test_result_img = LabelToContourd(**input_param)({"img": img})["img"] + self.assertEqual(test_result_img.shape, img.shape) + self.assertEqual(type(test_result_img), type(img)) + if isinstance(test_result_img, torch.Tensor): + test_result_img = test_result_img.cpu() + + for channel in range(channels): + np.testing.assert_allclose(test_result_img[channel, ...], expected_output) + + # check invalid input data + error_input = {"img": p(torch.rand(1, 2))} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) + error_input = {"img": p(torch.rand(1, 2, 3, 4, 5))} + self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) if __name__ == "__main__": From da2b3a320b05843157e4443afe380d6b5731ed88 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 15:35:30 +0100 Subject: [PATCH 139/176] normalize intensity Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 7 +- monai/transforms/intensity/array.py | 60 ++++++++++--- tests/test_crop_foreground.py | 14 +-- tests/test_normalize_intensity.py | 134 ++++++++++++++++++---------- tests/test_normalize_intensityd.py | 88 +++++++++++------- 5 files changed, 201 insertions(+), 102 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 09b5c62d65..4e95face8d 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -700,13 +700,14 @@ def __call__(self, img: DataObjects.Images): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ - img_np: np.ndarray - img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + img_np = img + # img_np: np.ndarray + # img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore box_start, box_end = self.compute_bounding_box(img_np) cropped = self.crop_pad(img_np, box_start, box_end) - cropped, *_ = convert_data_type(cropped, orig_type, orig_device) + # cropped, *_ = convert_data_type(cropped, orig_type, orig_device) if self.return_coords: return cropped, box_start, box_end diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a615052dff..e4239c3d25 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -508,7 +508,7 @@ def __call__(self, img: np.ndarray): return (img * _bias_fields).astype(self.dtype) -class NormalizeIntensity(Transform): +class NormalizeIntensity(TorchTransform, NumpyTransform): """ Normalize input based on provided args, using calculated mean and std if not provided. 
This transform can normalize only non-zero values or entire image, and can also calculate
@@ -527,8 +527,8 @@

     def __init__(
         self,
-        subtrahend: Union[Sequence, np.ndarray, None] = None,
-        divisor: Union[Sequence, np.ndarray, None] = None,
+        subtrahend: Optional[Union[Sequence, DataObjects.Images]] = None,
+        divisor: Optional[Union[Sequence, DataObjects.Images]] = None,
         nonzero: bool = False,
         channel_wise: bool = False,
         dtype: DtypeLike = np.float32,
@@ -539,26 +539,54 @@
         self.channel_wise = channel_wise
         self.dtype = dtype

-    def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray:
-        slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool)
-        if not np.any(slices):
+    @staticmethod
+    def _mean(x):
+        if isinstance(x, np.ndarray):
+            return np.mean(x)
+        x = torch.mean(x.float())
+        return x.item() if x.numel() == 1 else x
+
+    @staticmethod
+    def _std(x):
+        if isinstance(x, np.ndarray):
+            return np.std(x)
+        # torch defaults to the unbiased (Bessel-corrected) estimate, numpy to the
+        # biased one; request the biased estimate so both backends agree
+        x = torch.std(x.float(), unbiased=False)
+        return x.item() if x.numel() == 1 else x
+
+    def _normalize(self, img: DataObjects.Images, sub=None, div=None) -> DataObjects.Images:
+        img, *_ = convert_data_type(img, dtype=torch.float32)
+
+        if self.nonzero:
+            slices = img != 0
+        else:
+            if isinstance(img, np.ndarray):
+                slices = np.ones_like(img, dtype=bool)
+            else:
+                slices = torch.ones_like(img, dtype=torch.bool)
+        if not slices.any():
             return img
-        _sub = sub if sub is not None else np.mean(img[slices])
-        if isinstance(_sub, np.ndarray):
+        _sub = sub if sub is not None else self._mean(img[slices])
+        if isinstance(_sub, (np.ndarray, torch.Tensor)):
+            _sub, *_ = convert_data_type(
+                _sub, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None
+            )
             _sub = _sub[slices]
-        _div = div if div is not None else np.std(img[slices])
+        _div = div if div is not None else self._std(img[slices])
         if np.isscalar(_div):
             if _div == 0.0:
                 _div = 1.0
-        elif isinstance(_div, np.ndarray):
+        elif isinstance(_div, (np.ndarray, torch.Tensor)):
+            _div, *_ = convert_data_type(
+                _div, type(img), dtype=img.dtype, device=img.device if isinstance(img, torch.Tensor) else None
+            )
             _div = _div[slices]
             _div[_div == 0.0] = 1.0
         img[slices] = (img[slices] - _sub) / _div
         return img

-    def __call__(self, img: np.ndarray) -> np.ndarray:
+    def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
         """
@@ -569,7 +597,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
                 raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

             for i, d in enumerate(img):
-                img[i] = self._normalize(
+                img[i] = self._normalize(  # type: ignore
                     d,
                     sub=self.subtrahend[i] if self.subtrahend is not None else None,
                     div=self.divisor[i] if self.divisor is not None else None,
@@ -577,7 +605,8 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
         else:
             img = self._normalize(img, self.subtrahend, self.divisor)

-        return img.astype(self.dtype)
+        out, *_ = convert_data_type(img, dtype=self.dtype)
+        return out


 class ThresholdIntensity(TorchTransform, NumpyTransform):
diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py
index 7235527e3b..766147ed72 100644
--- a/tests/test_crop_foreground.py
+++ b/tests/test_crop_foreground.py
@@ -24,7 +24,7 @@
     TESTS.append(
         [
             {"select_fn": lambda x: x > 0,
"channel_indices": None, "margin": 0}, - p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), ] ) @@ -32,7 +32,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), p([[[3]]]), ] ) @@ -40,7 +40,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), ] ) @@ -48,7 +48,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), ] ) @@ -56,7 +56,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), ] ) @@ -64,7 +64,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), ] ) @@ -72,7 +72,7 @@ TESTS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, - p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), # type: ignore + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), p(np.zeros((1, 0, 0), dtype=np.int64)), ] ) diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 9d474faea7..2d26eb6dd3 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -12,70 +12,108 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensity -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASES = [ - [{"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])], - [ - {"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True}, - np.array([0.0, 3.0, 0.0, 4.0]), - np.array([0.0, -1.0, 0.0, 1.0]), - ], - [{"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])], - [{"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])], - [ - {"nonzero": False, 
"channel_wise": True, "subtrahend": [1, 2, 3]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, - np.ones((3, 2, 2)), - np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * -1.0, - ], - [ - {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], - [ - {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, - np.ones((3, 2, 2)), - np.ones((3, 2, 2)) * 0.5, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])]) + TESTS.append( + [ + p, + {"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True}, + np.array([0.0, 3.0, 0.0, 4.0]), + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])]) + TESTS.append( + [ + p, + {"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]}, + np.ones((3, 2, 2)), + np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]), + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * -1.0, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) + TESTS.append( + [ + p, + {"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]}, + np.ones((3, 2, 2)), + np.ones((3, 2, 2)) * 0.5, + ] + ) class TestNormalizeIntensity(NumpyImageTestCase2D): - def test_default(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_default(self, im_type): + im = im_type(self.imt.copy()) normalizer = NormalizeIntensity() - normalized = normalizer(self.imt.copy()) - self.assertTrue(normalized.dtype == np.float32) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + self.assertTrue(normalized.dtype in (np.float32, torch.float32)) expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) np.testing.assert_allclose(normalized, expected, rtol=1e-3) - @parameterized.expand(TEST_CASES) - def test_nonzero(self, input_param, input_data, expected_data): + @parameterized.expand(TESTS) + def test_nonzero(self, in_type, input_param, input_data, expected_data): normalizer = NormalizeIntensity(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)) + im = 
in_type(input_data) + normalized = normalizer(im) + self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(expected_data, normalized) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): normalizer = NormalizeIntensity(nonzero=True, channel_wise=True) - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)) + normalized = normalizer(input_data) + self.assertEqual(type(input_data), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(expected, normalized) - def test_value_errors(self): - input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]) + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_value_errors(self, im_type): + input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1]) with self.assertRaises(ValueError): normalizer(input_data) diff --git a/tests/test_normalize_intensityd.py b/tests/test_normalize_intensityd.py index 482d1a3f5b..2a45260b65 100644 --- a/tests/test_normalize_intensityd.py +++ b/tests/test_normalize_intensityd.py @@ -12,54 +12,80 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import NormalizeIntensityd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D -TEST_CASE_1 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 0.0, 1.0]), -] - -TEST_CASE_2 = [ - { - "keys": ["img"], - "subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), - "divisor": np.array([0.5, 0.5, 0.5, 0.5]), - "nonzero": True, - }, - {"img": np.array([0.0, 3.0, 0.0, 4.0])}, - np.array([0.0, -1.0, 0.0, 1.0]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "nonzero": True}, - {"img": np.array([0.0, 0.0, 0.0, 0.0])}, - np.array([0.0, 0.0, 0.0, 0.0]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])), + "divisor": q(np.array([0.5, 0.5, 0.5, 0.5])), + "nonzero": True, + }, + {"img": p(np.array([0.0, 3.0, 0.0, 4.0]))}, + np.array([0.0, -1.0, 0.0, 1.0]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "nonzero": True}, + {"img": p(np.array([0.0, 0.0, 0.0, 0.0]))}, + np.array([0.0, 0.0, 0.0, 0.0]), + ] + ) class TestNormalizeIntensityd(NumpyImageTestCase2D): - def test_image_normalize_intensityd(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_image_normalize_intensityd(self, im_type): key = "img" + im = im_type(self.imt) normalizer = NormalizeIntensityd(keys=[key]) - normalized = normalizer({key: self.imt}) + normalized = normalizer({key: im})[key] expected = (self.imt - np.mean(self.imt)) / np.std(self.imt) - np.testing.assert_allclose(normalized[key], expected, rtol=1e-3) + 
self.assertEqual(type(im), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(im.device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(normalized, expected, rtol=1e-3) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_nonzero(self, input_param, input_data, expected_data): + key = "img" normalizer = NormalizeIntensityd(**input_param) - np.testing.assert_allclose(expected_data, normalizer(input_data)["img"]) + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) + normalized = normalized.cpu() + np.testing.assert_allclose(normalized, expected_data) - def test_channel_wise(self): + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, im_type): key = "img" normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True) - input_data = {key: np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])} + input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))} + normalized = normalizer(input_data)[key] + self.assertEqual(type(input_data[key]), type(normalized)) + if isinstance(normalized, torch.Tensor): + self.assertEqual(input_data[key].device, normalized.device) + normalized = normalized.cpu() expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) - np.testing.assert_allclose(expected, normalizer(input_data)[key]) + np.testing.assert_allclose(expected, normalized) if __name__ == "__main__": From b50fcce7fb09a6826bbd5a7209616a8a588e1be6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 16:43:20 +0100 Subject: [PATCH 140/176] crop foreground Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 22 ++-- tests/test_crop_foreground.py | 8 +- tests/test_crop_foregroundd.py | 165 +++++++++++++++++++----------- 3 files changed, 118 insertions(+), 77 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 4e95face8d..a0ef82cd7f 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -602,7 +602,7 @@ def __call__(self, img: np.ndarray) -> List[np.ndarray]: return [self.cropper(img) for _ in range(self.num_samples)] -class CropForeground(NumpyTransform): +class CropForeground(TorchTransform, NumpyTransform): """ Crop an image using a bounding box. The bounding box is generated by selecting foreground using select_fn at channels channel_indices. margin is added in each spatial dimension of the bounding box. @@ -671,9 +671,9 @@ def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np. And adjust bounding box coords to be divisible by `k`. 
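        For example, a minimal sketch of the `k_divisible` behaviour, reusing values
        from the accompanying tests (assuming a numpy input):

            cropper = CropForeground(select_fn=lambda x: x > 0, margin=0, k_divisible=4)
            img = np.array([[[0, 0, 0, 0, 0],
                             [0, 1, 2, 1, 0],
                             [0, 2, 3, 2, 0],
                             [0, 1, 2, 1, 0],
                             [0, 0, 0, 0, 0]]])
            cropped = cropper(img)
            # the detected 3x3 foreground box is enlarged to 4x4 so each spatial
            # dimension is divisible by 4, zero-padding where the enlarged box
            # extends past the original image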
""" - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore - box_start, box_end = generate_spatial_bounding_box(img_np, self.select_fn, self.channel_indices, self.margin) + box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) + box_start = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_start] # type: ignore + box_end = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_end] # type: ignore box_start_ = np.asarray(box_start, dtype=np.int16) box_end_ = np.asarray(box_end, dtype=np.int16) orig_spatial_size = box_end_ - box_start_ @@ -684,12 +684,12 @@ def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np. box_end_ = box_start_ + spatial_size return box_start_, box_end_ - def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray): + def crop_pad(self, img: DataObjects.Images, box_start: np.ndarray, box_end: np.ndarray) -> DataObjects.Images: """ Crop and pad based on the bounding box. """ - cropped: np.ndarray = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) # type: ignore + cropped = SpatialCrop(roi_start=box_start, roi_end=box_end)(img) pad_to_start = np.maximum(-box_start, 0) pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0) pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) @@ -700,14 +700,8 @@ def __call__(self, img: DataObjects.Images): Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ - img_np = img - # img_np: np.ndarray - # img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore - - box_start, box_end = self.compute_bounding_box(img_np) - cropped = self.crop_pad(img_np, box_start, box_end) - - # cropped, *_ = convert_data_type(cropped, orig_type, orig_device) + box_start, box_end = self.compute_bounding_box(img) + cropped = self.crop_pad(img, box_start, box_end) if self.return_coords: return cropped, box_start, box_end diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 766147ed72..0bae1f90f3 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -18,10 +18,10 @@ from monai.transforms import CropForeground from tests.utils import TEST_NDARRAYS -TESTS = [] +TEST_COORDS, TESTS = [], [] for p in TEST_NDARRAYS: - TESTS.append( + TEST_COORDS.append( [ {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), @@ -79,12 +79,12 @@ class TestCropForeground(unittest.TestCase): - @parameterized.expand(TESTS) + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) - @parameterized.expand([TESTS[0]]) + @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index 37abfb8c55..14367ebfd7 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -12,79 +12,126 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForegroundd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - 
"source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - { - "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), - "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), - }, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TEST_POSITION, TESTS = [], [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - { - "keys": ["img"], - "source_key": "img", - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - "k_divisible": [4, 6], - "mode": "edge", - }, - {"img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), -] + TEST_POSITION.append( + [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + { + "img": p( + np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]) + ), + "label": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]) + ), + }, + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[3]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 
0, + "channel_indices": None, + "margin": [2, 1], + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + "k_divisible": [4, 6], + "mode": "edge", + }, + { + "img": p( + np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]) + ) + }, + np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), + ] + ) class TestCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TEST_POSITION + TESTS) + def test_value(self, argments, input_data, expected_data): + result = CropForegroundd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, expected_data) - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + @parameterized.expand(TEST_POSITION) + def test_foreground_position(self, argments, input_data, _): + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) argments["start_coord_key"] = "test_start_coord" argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) From 73e3e7078dd46178696c1ff2da8b2c9014c6f2f6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 17:19:29 +0100 Subject: [PATCH 141/176] LabelToMask[d] Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 22 ++++++--- tests/test_label_to_mask.py | 74 ++++++++++++++++------------ tests/test_label_to_maskd.py | 81 ++++++++++++++++++------------- 3 files changed, 108 insertions(+), 69 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 729996c0bc..cf2fd8f817 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -634,7 +634,7 @@ def __call__(self, img: DataObjects.Images, func: Optional[Callable] = None): raise ValueError("Incompatible values: func=None and self.func=None.") -class LabelToMask(NumpyTransform): +class LabelToMask(TorchTransform, NumpyTransform): """ Convert labels to mask for other tasks. A typical usage is to convert segmentation labels to mask data to pre-process images and then feed the images into classification network. 
@@ -661,6 +661,13 @@ def __init__( # pytype: disable=annotation-type-mismatch self.select_labels = ensure_tuple(select_labels) self.merge_channels = merge_channels + @staticmethod + def _in1d(x, y): + if isinstance(x, np.ndarray): + return np.in1d(x, y) + else: + return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + def __call__( self, img: DataObjects.Images, @@ -675,7 +682,6 @@ def __call__( merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. """ - img, orig_type, orig_device = convert_data_type(img, np.ndarray) if select_labels is None: select_labels = self.select_labels @@ -685,11 +691,15 @@ def __call__( if img.shape[0] > 1: data = img[[*select_labels]] else: - data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape) + where = np.where if isinstance(img, np.ndarray) else torch.where + data = where(self._in1d(img, select_labels), True, False).reshape(img.shape) - out_np = np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data - out, *_ = convert_data_type(out_np, orig_type, orig_device) - return out + if merge_channels or self.merge_channels: + if isinstance(data, np.ndarray): + return np.any(data, axis=0, keepdims=True) # type: ignore + else: + return torch.any(data, dim=0, keepdim=True) + return data class FgBgToIndices(Transform): diff --git a/tests/test_label_to_mask.py b/tests/test_label_to_mask.py index 2a84c7bea6..dc785ba24e 100644 --- a/tests/test_label_to_mask.py +++ b/tests/test_label_to_mask.py @@ -12,45 +12,59 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMask +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_labels": [2, 3], "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]), - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"select_labels": [1, 2], "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"select_labels": 2, "merge_channels": False}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"select_labels": [1, 2], "merge_channels": True}, - np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"select_labels": [2, 3], "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])), + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], 
[1, 1, 0]]])), + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": 2, "merge_channels": False}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"select_labels": [1, 2], "merge_channels": True}, + p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])), + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMask(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = LabelToMask(**argments)(image) + self.assertEqual(type(result), type(image)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, image.device) + result = result.cpu() np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_label_to_maskd.py b/tests/test_label_to_maskd.py index f046390c19..59e4ddb047 100644 --- a/tests/test_label_to_maskd.py +++ b/tests/test_label_to_maskd.py @@ -12,46 +12,61 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import LabelToMaskd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_2 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])}, - np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), -] - -TEST_CASE_3 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_4 = [ - {"keys": "img", "select_labels": 2, "merge_channels": False}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, - {"img": np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])}, - np.array([[[1, 0, 1], [1, 1, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "select_labels": [2, 3], "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))}, + np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": 2, "merge_channels": False}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 
0]]]))}, + np.array([[[1, 0, 1], [1, 1, 0]]]), + ] + ) + TESTS.append( + [ + {"keys": "img", "select_labels": [1, 2], "merge_channels": True}, + {"img": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))}, + np.array([[[1, 0, 1], [1, 1, 1]]]), + ] + ) class TestLabelToMaskd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_value(self, argments, image, expected_data): - result = LabelToMaskd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TESTS) + def test_value(self, argments, input_data, expected_data): + result = LabelToMaskd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu() + np.testing.assert_allclose(r, expected_data) if __name__ == "__main__": From e1d81edea0c68a13aad23530fdd91756b82eca48 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 20 Jul 2021 18:34:56 +0100 Subject: [PATCH 142/176] RandCropByPosNegLabel Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 34 +++--- monai/transforms/utility/array.py | 2 +- monai/transforms/utils.py | 31 ++++-- tests/test_rand_crop_by_pos_neg_label.py | 129 +++++++++++++--------- tests/test_rand_crop_by_pos_neg_labeld.py | 41 ++++--- 5 files changed, 140 insertions(+), 97 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a0ef82cd7f..183369a8f3 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -760,7 +760,7 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> return results -class RandCropByPosNegLabel(Randomizable, Transform): +class RandCropByPosNegLabel(Randomizable, TorchTransform, NumpyTransform): """ Crop random fixed sized regions with the center being a foreground or background voxel based on the Pos Neg Ratio. 
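Before the signature changes in the next hunks, a quick orientation on how this transform is driven end to end. A minimal usage sketch with illustrative shapes and values (not part of the patch; with this series applied, `img` and `label` may be numpy arrays or torch tensors interchangeably):

    import numpy as np
    from monai.transforms import RandCropByPosNegLabel

    img = np.random.rand(1, 10, 10)
    label = (img > 0.5).astype(np.float32)

    # pos=1, neg=1 gives a 50/50 chance of centering each patch on a
    # foreground or a background voxel of `label`
    cropper = RandCropByPosNegLabel(spatial_size=(4, 4), label=label, pos=1, neg=1, num_samples=2)
    patches = cropper(img)
    print(len(patches), patches[0].shape)  # 2 (1, 4, 4)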
@@ -813,14 +813,14 @@ class RandCropByPosNegLabel(Randomizable, Transform): def __init__( self, spatial_size: Union[Sequence[int], int], - label: Optional[np.ndarray] = None, + label: Optional[DataObjects.Images] = None, pos: float = 1.0, neg: float = 1.0, num_samples: int = 1, - image: Optional[np.ndarray] = None, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.label = label @@ -838,10 +838,10 @@ def __init__( def randomize( self, - label: np.ndarray, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: @@ -859,12 +859,12 @@ def randomize( def __call__( self, - img: np.ndarray, - label: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - fg_indices: Optional[np.ndarray] = None, - bg_indices: Optional[np.ndarray] = None, - ) -> List[np.ndarray]: + img: DataObjects.Images, + label: Optional[DataObjects.Images] = None, + image: Optional[DataObjects.Images] = None, + fg_indices: Optional[DataObjects.Images] = None, + bg_indices: Optional[DataObjects.Images] = None, + ) -> List[DataObjects.Images]: """ Args: img: input data to crop samples from based on the pos/neg ratio of `label` and `image`. @@ -887,11 +887,11 @@ def __call__( image = self.image self.randomize(label, fg_indices, bg_indices, image) - results: List[np.ndarray] = [] + results: List[DataObjects.Images] = [] if self.centers is not None: for center in self.centers: cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore - results.append(cropper(img)) # type: ignore + results.append(cropper(img)) return results diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index cf2fd8f817..b45fbec0b1 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -725,7 +725,7 @@ def __call__( label: np.ndarray, image: Optional[np.ndarray] = None, output_shape: Optional[Sequence[int]] = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Args: label: input data to compute foreground and background indices. diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index af169ac09f..b80e121fca 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -251,10 +251,10 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, image_threshold: float = 0.0, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[DataObjects.Images, DataObjects.Images]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -269,16 +269,22 @@ def map_binary_to_indices( determine the valid image content area and select background only in this area. 
""" + + def _nonzero(x): + if isinstance(x, np.ndarray): + return np.nonzero(x)[0] + return torch.nonzero(x).flatten() + # Prepare fg/bg indices if label.shape[0] > 1: label = label[1:] # for One-Hot format data, remove the background channel - label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions - fg_indices = np.nonzero(label_flat)[0] + label_flat = label.any(0).ravel() # in case label has multiple dimensions + fg_indices = _nonzero(label_flat) if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() - bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] + img_flat = (image > image_threshold).any(0).ravel() + bg_indices = _nonzero(img_flat & ~label_flat) else: - bg_indices = np.nonzero(~label_flat)[0] + bg_indices = _nonzero(~label_flat) return fg_indices, bg_indices @@ -424,8 +430,8 @@ def generate_pos_neg_label_crop_centers( num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], - fg_indices: np.ndarray, - bg_indices: np.ndarray, + fg_indices: DataObjects.Images, + bg_indices: DataObjects.Images, rand_state: Optional[np.random.RandomState] = None, ) -> List[List[np.ndarray]]: """ @@ -450,7 +456,6 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) if fg_indices.size == 0 and bg_indices.size == 0: raise ValueError("No sampling location available.") @@ -464,7 +469,9 @@ def generate_pos_neg_label_crop_centers( for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx = indices_to_use[random_int] + idx = idx.cpu() if isinstance(idx, torch.Tensor) else idx + center = np.unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index e0f669ab3f..a81976dea1 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -10,68 +10,93 @@ # limitations under the License. 
import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabel +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, -1], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 3), -] +TESTS = [] +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, -1], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 3), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 2, 2, 2), + ] +) +TESTS.append( + [ + { + "label": None, + "spatial_size": [2, 2, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + (3, 2, 2, 2), + ] +) -TEST_CASE_1 = [ - { - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, - list, - (3, 2, 2, 2), -] -TEST_CASE_2 = [ - { - "label": None, - "spatial_size": [2, 2, 2], - "pos": 1, - "neg": 1, - "num_samples": 2, - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image_threshold": 0, - }, - { - "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - }, - list, - (3, 2, 2, 2), -] +class TestRandCropByPosNegLabel(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand(TESTS) + def test_type_shape(self, input_param, input_data, expected_shape): + results = [] + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabel(**input_param_mod) + cropper.set_random_state(0) + result = cropper(**input_data_mod) -class TestRandCropByPosNegLabel(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabel(**input_param)(**input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertIsInstance(result, list) + self.assertTupleEqual(result[0].shape, expected_shape) + + # check for same results across numpy, torch.Tensor and torch.cuda.Tensor + result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result]) + 
results.append(np.asarray(result)) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 17a3e117bb..03ec87bb4d 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -10,11 +10,13 @@ # limitations under the License. import unittest +from copy import deepcopy import numpy as np from parameterized import parameterized from monai.transforms import RandCropByPosNegLabeld +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ { @@ -33,7 +35,6 @@ "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 3, 2, 2), ] @@ -54,7 +55,6 @@ "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 2, 2, 2), ] @@ -75,25 +75,36 @@ "label": np.ones([3, 3, 3, 3]), "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, - list, (3, 2, 2, 2), ] class TestRandCropByPosNegLabeld(unittest.TestCase): + @staticmethod + def convert_data_type(im_type, d, keys=("img", "image", "label")): + out = deepcopy(d) + for k, v in out.items(): + if k in keys and isinstance(v, np.ndarray): + out[k] = im_type(v) + return out + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_type_shape(self, input_param, input_data, expected_type, expected_shape): - result = RandCropByPosNegLabeld(**input_param)(input_data) - self.assertIsInstance(result, expected_type) - self.assertTupleEqual(result[0]["image"].shape, expected_shape) - self.assertTupleEqual(result[0]["extra"].shape, expected_shape) - self.assertTupleEqual(result[0]["label"].shape, expected_shape) - _len = len(tuple(input_data.keys())) - self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) - for i, item in enumerate(result): - self.assertEqual(item["image_meta_dict"]["patch_index"], i) - self.assertEqual(item["label_meta_dict"]["patch_index"], i) - self.assertEqual(item["extra_meta_dict"]["patch_index"], i) + def test_type_shape(self, input_param, input_data, expected_shape): + for p in TEST_NDARRAYS: + input_param_mod = self.convert_data_type(p, input_param) + input_data_mod = self.convert_data_type(p, input_data) + cropper = RandCropByPosNegLabeld(**input_param_mod) + cropper.set_random_state(0) + result = cropper(input_data_mod) + + self.assertIsInstance(result, list) + + _len = len(tuple(input_data.keys())) + self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) + for k in ("image", "extra", "label"): + self.assertTupleEqual(result[0][k].shape, expected_shape) + for i, item in enumerate(result): + self.assertEqual(item[k + "_meta_dict"]["patch_index"], i) if __name__ == "__main__": From d70f3190859214b9638c247dc6a78f016c2ca4fb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 21 Jul 2021 11:04:17 +0100 Subject: [PATCH 143/176] center scale crop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 4 +-- tests/test_center_scale_crop.py | 39 +++++++++++++++++---------- tests/test_center_scale_cropd.py | 45 ++++++++++++++++++++----------- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 183369a8f3..68900188cd 100644 --- a/monai/transforms/croppad/array.py +++ 
b/monai/transforms/croppad/array.py @@ -409,7 +409,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return cropper(img) -class CenterScaleCrop(Transform): +class CenterScaleCrop(TorchTransform, NumpyTransform): """ Crop at the center of image with specified scale of ROI size. @@ -422,7 +422,7 @@ class CenterScaleCrop(Transform): def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: img_size = img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py index e28849ce90..cada8e273c 100644 --- a/tests/test_center_scale_crop.py +++ b/tests/test_center_scale_crop.py @@ -16,33 +16,44 @@ from parameterized import parameterized from monai.transforms import CenterScaleCrop +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] +TESTS, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TESTS.append([{"roi_scale": [0.6, 0.3, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 1, 3)]) -TEST_CASE_1 = [{"roi_scale": 0.6}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] + TESTS.append([{"roi_scale": 0.6}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) -TEST_CASE_2 = [ - {"roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), -] + TESTS.append( + [ + {"roi_scale": 0.5}, + p(torch.randint(0, 2, size=[3, 3, 3, 3])), + (3, 2, 2, 2), + ] + ) -TEST_CASE_3 = [ - {"roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_VALUES.append( + [ + {"roi_scale": [0.4, 0.4]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + np.array([[[1, 2], [2, 3]]]), + ] + ) class TestCenterScaleCrop(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCrop(**input_param)(input_data) np.testing.assert_allclose(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): result = CenterScaleCrop(**input_param)(input_data) + self.assertEqual(type(result), type(input_data)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data.device) + result = result.cpu() np.testing.assert_allclose(result, expected_value) diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 313e8e7f7e..c5b680b26c 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -16,34 +16,47 @@ from parameterized import parameterized from monai.transforms import CenterScaleCropd +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 1, 3)] +TESTS, TEST_VALUES = [], [] +for p in TEST_NDARRAYS: + TESTS.append( + [{"keys": "img", "roi_scale": [0.6, 0.3, -1]}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 1, 3)] + ) -TEST_CASE_1 = [{"keys": "img", "roi_scale": 0.6}, 
np.random.randint(0, 2, size=[3, 3, 3, 3]), (3, 2, 2, 2)] + TESTS.append([{"keys": "img", "roi_scale": 0.6}, p(np.random.randint(0, 2, size=[3, 3, 3, 3])), (3, 2, 2, 2)]) -TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [0.4, 0.4]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2], [2, 3]]]), -] + TESTS.append( + [ + {"keys": "img", "roi_scale": 0.5}, + p(torch.randint(0, 2, size=[3, 3, 3, 3])), + (3, 2, 2, 2), + ] + ) -TEST_CASE_3 = [ - {"keys": "img", "roi_scale": 0.5}, - torch.randint(0, 2, size=[3, 3, 3, 3], device="cuda" if torch.cuda.is_available() else "cpu"), - (3, 2, 2, 2), -] + TEST_VALUES.append( + [ + {"keys": "img", "roi_scale": [0.4, 0.4]}, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])), + np.array([[[1, 2], [2, 3]]]), + ] + ) class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCropd(**input_param)({"img": input_data}) np.testing.assert_allclose(result["img"].shape, expected_shape) - @parameterized.expand([TEST_CASE_2]) + @parameterized.expand(TEST_VALUES) def test_value(self, input_param, input_data, expected_value): - result = CenterScaleCropd(**input_param)({"img": input_data}) - np.testing.assert_allclose(result["img"], expected_value) + result = CenterScaleCropd(**input_param)({"img": input_data})["img"] + self.assertEqual(type(result), type(input_data)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data.device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) if __name__ == "__main__": From d1fa1d7e1cb01a74f2e7de7b42f0b42a4210345a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 21 Jul 2021 11:35:05 +0100 Subject: [PATCH 144/176] resizewithpadorcrop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 6 +++--- tests/test_resize_with_pad_or_crop.py | 24 ++++++++++++++++++------ tests/test_resize_with_pad_or_cropd.py | 20 +++++++++++++++++--- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 68900188cd..b50fb91308 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -1028,7 +1028,7 @@ def __call__( return results -class ResizeWithPadOrCrop(Transform): +class ResizeWithPadOrCrop(TorchTransform, NumpyTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1059,7 +1059,7 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **np_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: + def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: """ Args: img: data to pad or crop, assuming `img` is channel-first and @@ -1070,7 +1070,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N If None, defaults to the ``mode`` in construction. 
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ - return self.padder(self.cropper(img), mode=mode) # type: ignore + return self.padder(self.cropper(img), mode=mode) class BoundingRect(TorchTransform, NumpyTransform): diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 46f1fc86cc..5f512a22f1 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch +from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -48,12 +50,22 @@ class TestResizeWithPadOrCrop(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_pad_shape(self, input_param, input_shape, expected_shape): - paddcroper = ResizeWithPadOrCrop(**input_param) - result = paddcroper(np.zeros(input_shape)) - np.testing.assert_allclose(result.shape, expected_shape) - result = paddcroper(np.zeros(input_shape), mode="constant") - np.testing.assert_allclose(result.shape, expected_shape) - + results1, results2 = [], [] + for p in TEST_NDARRAYS: + im = p(np.zeros(input_shape)) + paddcroper = ResizeWithPadOrCrop(**input_param) + result1 = paddcroper(im) + result2 = paddcroper(im, mode="constant") + for result, results in zip((result1, result2), (results1, results2)): + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + result = result.cpu().numpy() + np.testing.assert_allclose(result.shape, expected_shape) + results.append(result) + # check output from numpy torch and torch.cuda match + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": unittest.main() diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 32a62a9e16..9c2d55b5a2 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
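The hunks below rewrite this test to push every backend through the same transform. Condensed, the contract they establish on this branch is that `ResizeWithPadOrCrop` hands back the same container type (and device) it was given, cropping and padding per dimension as needed. A sketch of that contract, not part of the patch:

    import numpy as np
    import torch

    from monai.transforms import ResizeWithPadOrCrop

    resizer = ResizeWithPadOrCrop(spatial_size=(8, 8))
    for img in (np.zeros((1, 15, 4)), torch.zeros(1, 15, 4)):
        out = resizer(img)             # crops dim 1 (15 -> 8), pads dim 2 (4 -> 8)
        assert type(out) is type(img)  # container type preserved
        assert tuple(out.shape) == (1, 8, 8)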
+import torch +from tests.utils import TEST_NDARRAYS import unittest import numpy as np @@ -47,10 +49,22 @@ class TestResizeWithPadOrCropd(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_pad_shape(self, input_param, input_data, expected_val): + def test_pad_shape(self, input_param, input_data, expected_shape): paddcroper = ResizeWithPadOrCropd(**input_param) - result = paddcroper(input_data) - np.testing.assert_allclose(result["img"].shape, expected_val) + results = [] + for p in TEST_NDARRAYS: + input_data_mod = {"img": p(input_data["img"])} + result = paddcroper(input_data_mod) + r, i = result["img"], input_data_mod["img"] + self.assertEqual(type(i), type(r)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + r = r.cpu().numpy() + np.testing.assert_allclose(r.shape, expected_shape) + results.append(r) + # check output from numpy torch and torch.cuda match + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": From 87264a25c6db63bfb987352f9f4fa93fdfe35933 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 21 Jul 2021 16:00:20 +0100 Subject: [PATCH 145/176] randweightedcrop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 48 +++-- monai/transforms/intensity/array.py | 4 - monai/transforms/utils.py | 31 ++- tests/test_rand_weighted_crop.py | 252 +++++++++++++----------- tests/test_rand_weighted_cropd.py | 260 +++++++++++++------------ tests/test_resize_with_pad_or_crop.py | 5 +- tests/test_resize_with_pad_or_cropd.py | 4 +- tests/test_spatial_crop.py | 29 ++- 8 files changed, 359 insertions(+), 274 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b50fb91308..16daaf948a 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -333,10 +333,10 @@ class SpatialCrop(TorchTransform, NumpyTransform): def __init__( self, - roi_center: Union[Sequence[int], np.ndarray, None] = None, - roi_size: Union[Sequence[int], np.ndarray, None] = None, - roi_start: Union[Sequence[int], np.ndarray, None] = None, - roi_end: Union[Sequence[int], np.ndarray, None] = None, + roi_center: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_size: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_start: Optional[Union[Sequence[int], DataObjects.Images]] = None, + roi_end: Optional[Union[Sequence[int], DataObjects.Images]] = None, roi_slices: Optional[Sequence[slice]] = None, ) -> None: """ @@ -355,20 +355,21 @@ def __init__( self.slices = list(roi_slices) else: if roi_center is not None and roi_size is not None: - roi_center = np.asarray(roi_center, dtype=np.int16) - roi_size = np.asarray(roi_size, dtype=np.int16) - roi_start_np = np.maximum(roi_center - np.floor_divide(roi_size, 2), 0) - roi_end_np = np.maximum(roi_start_np + roi_size, roi_start_np) + roi_center = torch.as_tensor(roi_center, dtype=torch.int16) + roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device) + roi_start = torch.maximum( + roi_center - torch.div(roi_size, 2, rounding_mode="floor"), + torch.tensor(0, device=roi_center.device), + ) + roi_end = torch.maximum(roi_start + roi_size, roi_start) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start_np = np.maximum(np.asarray(roi_start, dtype=np.int16), 0) - roi_end_np = 
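The `SpatialCrop` hunk earlier in this patch moves the start/end arithmetic from numpy to torch. `torch.div(..., rounding_mode="floor")` stands in for the floor division because plain `//` on tensors was being deprecated in favor of explicit rounding modes. Worked through with illustrative values:

    import torch

    roi_center = torch.as_tensor([2, 2], dtype=torch.int16)
    roi_size = torch.as_tensor([3, 3], dtype=torch.int16)
    zero = torch.tensor(0)

    roi_start = torch.maximum(roi_center - torch.div(roi_size, 2, rounding_mode="floor"), zero)
    roi_end = torch.maximum(roi_start + roi_size, roi_start)
    print([slice(int(s), int(e)) for s, e in zip(roi_start, roi_end)])
    # [slice(1, 4), slice(1, 4)]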
np.maximum(np.asarray(roi_end, dtype=np.int16), roi_start_np) - # Allow for 1D by converting back to np.array (since np.maximum will convert to int) - roi_start_np = roi_start_np if isinstance(roi_start_np, np.ndarray) else np.array([roi_start_np]) - roi_end_np = roi_end_np if isinstance(roi_end_np, np.ndarray) else np.array([roi_end_np]) + roi_start = torch.as_tensor(roi_start, dtype=torch.int16) + roi_start = torch.maximum(roi_start, torch.tensor(0, device=roi_start.device)) + roi_end = torch.maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start) # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start_np, roi_end_np)] + self.slices = [slice(s, e) for s, e in zip(roi_start, roi_end)] def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ @@ -708,7 +709,7 @@ def __call__(self, img: DataObjects.Images): return cropped -class RandWeightedCrop(Randomizable, Transform): +class RandWeightedCrop(Randomizable, TorchTransform, NumpyTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -722,19 +723,24 @@ class RandWeightedCrop(Randomizable, Transform): """ def __init__( - self, spatial_size: Union[Sequence[int], int], num_samples: int = 1, weight_map: Optional[np.ndarray] = None + self, + spatial_size: Union[Sequence[int], int], + num_samples: int = 1, + weight_map: Optional[DataObjects.Images] = None, ): self.spatial_size = ensure_tuple(spatial_size) self.num_samples = int(num_samples) self.weight_map = weight_map - self.centers: List[np.ndarray] = [] + self.centers: List[DataObjects.Images] = [] - def randomize(self, weight_map: np.ndarray) -> None: + def randomize(self, weight_map: DataObjects.Images) -> None: self.centers = weighted_patch_samples( spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R ) # using only the first channel as weight map - def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]: + def __call__( + self, img: DataObjects.Images, weight_map: Optional[DataObjects.Images] = None + ) -> List[DataObjects.Images]: """ Args: img: input image to sample patches from. assuming `img` is a channel-first array. 
@@ -753,10 +759,10 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.") self.randomize(weight_map) _spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:]) - results: List[np.ndarray] = [] + results: List[DataObjects.Images] = [] for center in self.centers: cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) - results.append(cropper(img)) # type: ignore + results.append(cropper(img)) return results diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index e4239c3d25..a9aa1ba978 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -585,10 +585,6 @@ def _normalize(self, img: DataObjects.Images, sub=None, div=None) -> DataObjects _div[_div == 0.0] = 1.0 img[slices] = (img[slices] - _sub) / _div - if isinstance(img, torch.Tensor): - if img.dtype != torch.float32: - print("hi") - return img def __call__(self, img: DataObjects.Images) -> DataObjects.Images: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b80e121fca..5971ab0ea3 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -343,9 +343,19 @@ def map_classes_to_indices( return indices +def _unravel_index(idx, shape): + if isinstance(idx, torch.Tensor): + coord = [] + for dim in reversed(shape): + coord.insert(0, idx % dim) + idx = torch.div(idx, dim, rounding_mode="floor") + return torch.stack(coord) + return np.unravel_index(np.asarray(idx, dtype=int), shape) + + def weighted_patch_samples( spatial_size: Union[int, Sequence[int]], - w: np.ndarray, + w: DataObjects.Images, n_samples: int = 1, r_state: Optional[np.random.RandomState] = None, ) -> List: @@ -375,16 +385,25 @@ def weighted_patch_samples( v = w[s] # weight map in the 'valid' mode v_size = v.shape v = v.ravel() - if np.any(v < 0): + if (v < 0).any(): v -= np.min(v) # shifting to non-negative - v = v.cumsum() - if not v[-1] or not np.isfinite(v[-1]) or v[-1] < 0: # uniform sampling + v = v.cumsum(0) + if not v[-1] or not torch.as_tensor(v[-1]).isfinite() or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) + if isinstance(v, torch.Tensor): + idx = torch.as_tensor(idx, device=v.device) else: - idx = v.searchsorted(r_state.random(n_samples) * v[-1], side="right") + r = r_state.random(n_samples) + if isinstance(v, np.ndarray): + idx = v.searchsorted(r * v[-1], side="right") + else: + r = torch.as_tensor(r, device=v.device) + idx = torch.searchsorted(v, r * v[-1], right=True) # compensate 'valid' mode diff = np.minimum(win_size, img_size) // 2 - return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)] + if isinstance(v, torch.Tensor): + diff = torch.as_tensor(diff, device=v.device) + return [_unravel_index(i, v_size) + diff for i in idx] def correct_crop_centers( diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 39a9439122..5ee815b711 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -12,127 +12,161 @@ import unittest import numpy as np +import torch +from parameterized.parameterized import parameterized from monai.transforms.croppad.array import RandWeightedCrop -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D -class TestRandWeightedCrop2D(NumpyImageTestCase2D): - def 
test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) - - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) +def get_data(ndim): + im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D() + im_gen.setUp() + return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0] - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) +IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2) +IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3) -class TestRandWeightedCrop(NumpyImageTestCase3D): - def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCrop((8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) - def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) +TESTS = [] +for p in TEST_NDARRAYS: + im = SEG1_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + TESTS.append( + [ + "small roi 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), 
+ p(weight), + (1, 10, 12), + [[80, 21], [30, 17], [40, 31]], + ] + ) + im = IMT_2D + TESTS.append( + [ + "default roi 2d", + dict(spatial_size=(10, -1), num_samples=3), + p(im), + p(weight), + (1, 10, 64), + [[14, 32], [105, 32], [20, 32]], + ] + ) + im = SEGN_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + TESTS.append( + [ + "large roi 2d", + dict(spatial_size=(10000, 400), num_samples=3), + p(im), + p(weight), + (1, 128, 64), + [[64, 32], [64, 32], [64, 32]], + ] + ) + im = IMT_2D + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 2d", + dict(spatial_size=(20, 40), num_samples=3), + p(im), + p(weight), + (1, 20, 40), + [[63, 37], [31, 43], [66, 20]], + ] + ) + im = SEG1_3D + weight = np.zeros_like(im) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + TESTS.append( + [ + "small roi 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + p(weight), + (1, 8, 10, 12), + [[11, 23, 21], [5, 30, 17], [8, 40, 31]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + TESTS.append( + [ + "default roi 3d", + dict(spatial_size=(10, -1, -1), num_samples=3), + p(im), + p(weight), + (1, 10, 64, 80), + [[14, 32, 40], [41, 32, 40], [20, 32, 40]], + ] + ) + im = SEGN_3D + weight = np.zeros_like(im) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + TESTS.append( + [ + "large roi 3d", + dict(spatial_size=(10000, 400, 80), num_samples=3), + p(im), + p(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) + im = IMT_3D + weight = np.zeros_like(im) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + TESTS.append( + [ + "bad w 3d", + dict(spatial_size=(48, 64, 80), num_samples=3), + p(im), + p(weight), + (1, 48, 64, 80), + [[24, 32, 40], [24, 32, 40], [24, 32, 40]], + ] + ) - def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCrop((10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) - for res in result: - np.testing.assert_allclose(res, self.segn[0]) - def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCrop((48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan +class TestRandWeightedCrop(unittest.TestCase): + @parameterized.expand(TESTS) + def test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals): + crop = RandWeightedCrop(**input_params) crop.set_random_state(10) result = crop(img, weight) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + self.assertTrue(len(result) == input_params["num_samples"]) + np.testing.assert_allclose(result[0].shape, expected_shape) + for c, e in zip(crop.centers, expected_vals): + if isinstance(c, torch.Tensor): + c = c.cpu() + np.testing.assert_allclose(c, e) + # if 
desired ROI is larger than image, check image is unchanged + if all(s >= i for i, s in zip(img.shape[1:], input_params["spatial_size"])): + for res in result: + self.assertEqual(type(img), type(res)) + if isinstance(img, torch.Tensor): + self.assertEqual(res.device, img.device) + res = res.cpu() + np.testing.assert_allclose(res, img if isinstance(img, np.ndarray) else img.cpu()) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 367ce3beb9..e77ee19b63 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -14,148 +14,168 @@ import numpy as np from monai.transforms.croppad.dictionary import RandWeightedCropd -from tests.utils import NumpyImageTestCase2D, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D class TestRandWeightedCrop(NumpyImageTestCase2D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - d = {"img": img, "w": weight} - result = crop(d) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + d = {"img": p(img), "w": q(weight)} + result = crop(d) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 12)) + np.testing.assert_allclose(np.asarray(crop.centers), [[80, 21], [30, 17], [40, 31]]) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 40, 31] = 1 - weight[0, 80, 21] = 1 - crop.set_random_state(10) - data = {"im": img, "weight": weight, "others": np.nan} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) - np.testing.assert_allclose(result[1]["coords"], [105, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd("im", "weight", (10, -1), n_samples, "coords") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 40, 31] = 1 + weight[0, 80, 21] = 1 + crop.set_random_state(10) + data = {"im": p(img), "weight": q(weight), "others": np.nan} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["im"].shape, (1, 10, 64)) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32], [105, 32], [20, 32]]) + np.testing.assert_allclose(result[1]["coords"], [105, 32]) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") - weight = np.zeros_like(img) - weight[0, 30, 17] = 1.1 - weight[0, 10, 1] = 1 - crop.set_random_state(10) - data = {"img": img, "seg": 
self.imt[0], "weight": weight} - result = crop(data) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) - np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) - np.testing.assert_allclose(result[1]["location"], [64, 32]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "weight", (10000, 400), n_samples, "location") + weight = np.zeros_like(img) + weight[0, 30, 17] = 1.1 + weight[0, 10, 1] = 1 + crop.set_random_state(10) + data = {"img": p(img), "seg": p(self.imt[0]), "weight": q(weight)} + result = crop(data) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 128, 64)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 128, 64)) + np.testing.assert_allclose(np.asarray(crop.centers), [[64, 32], [64, 32], [64, 32]]) + np.testing.assert_allclose(result[1]["location"], [64, 32]) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) - np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (20, 40), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 20, 40)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 20, 40)) + np.testing.assert_allclose(np.asarray(crop.centers), [[63, 37], [31, 43], [66, 20]]) class TestRandWeightedCrop3D(NumpyImageTestCase3D): def test_rand_weighted_crop_small_roi(self): - img = self.seg1[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) - weight = np.zeros_like(img) - weight[0, 5, 30, 17] = 1.1 - weight[0, 8, 40, 31] = 1 - weight[0, 11, 23, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) - np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.seg1[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (8, 10, 12), n_samples) + weight = np.zeros_like(img) + weight[0, 5, 30, 17] = 1.1 + weight[0, 8, 40, 31] = 1 + weight[0, 11, 23, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 8, 10, 12)) + np.testing.assert_allclose(np.asarray(crop.centers), [[11, 23, 21], [5, 30, 17], [8, 40, 31]]) def test_rand_weighted_crop_default_roi(self): - img = self.imt[0] 
- n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) def test_rand_weighted_crop_large_roi(self): - img = self.segn[0] - n_samples = 3 - crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17, 20] = 1.1 - weight[0, 10, 1, 17] = 1 - crop.set_random_state(10) - result = crop({"img": img, "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.segn[0] + n_samples = 3 + crop = RandWeightedCropd("img", "w", (10000, 400, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17, 20] = 1.1 + weight[0, 10, 1, 17] = 1 + crop.set_random_state(10) + result = crop({"img": p(img), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) def test_rand_weighted_crop_bad_w(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) - weight = np.zeros_like(img) - weight[0, 30, 17] = np.inf - weight[0, 10, 1] = -np.inf - weight[0, 10, 20] = -np.nan - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) - np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (48, 64, 80), n_samples) + weight = np.zeros_like(img) + weight[0, 30, 17] = np.inf + weight[0, 10, 1] = -np.inf + weight[0, 10, 20] = -np.nan + crop.set_random_state(10) + result = crop({"img": p(img), "seg": p(self.segn[0]), "w": q(weight)}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(result[0]["img"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) + np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], 
[24, 32, 40], [24, 32, 40]]) def test_rand_weighted_crop_patch_index(self): - img = self.imt[0] - n_samples = 3 - crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) - weight = np.zeros_like(img) - weight[0, 7, 17] = 1.1 - weight[0, 13, 31] = 1.1 - weight[0, 24, 21] = 1 - crop.set_random_state(10) - result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}}) - self.assertTrue(len(result) == n_samples) - np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) - for i in range(n_samples): - np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) - np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) - np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop( + {"img": p(img), "seg": p(self.segn[0]), "w": q(weight), "img_meta_dict": {"affine": None}} + ) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for i in range(n_samples): + np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) + np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_resize_with_pad_or_crop.py b/tests/test_resize_with_pad_or_crop.py index 5f512a22f1..94a62292fe 100644 --- a/tests/test_resize_with_pad_or_crop.py +++ b/tests/test_resize_with_pad_or_crop.py @@ -9,14 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from tests.utils import TEST_NDARRAYS import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCrop +from tests.utils import TEST_NDARRAYS TEST_CASES = [ [ @@ -67,5 +67,6 @@ def test_pad_shape(self, input_param, input_shape, expected_shape): if len(results) > 1: np.testing.assert_allclose(results[0], results[-1]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resize_with_pad_or_cropd.py b/tests/test_resize_with_pad_or_cropd.py index 9c2d55b5a2..851ec18df7 100644 --- a/tests/test_resize_with_pad_or_cropd.py +++ b/tests/test_resize_with_pad_or_cropd.py @@ -9,14 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
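# For reference while reading these tests: ``TEST_NDARRAYS`` (imported from
# tests/utils above and below) is a tuple of callables that wrap the same data
# in each supported container type, so a test can loop ``for p in TEST_NDARRAYS``
# and assert that a transform behaves identically for numpy and torch inputs.
# A minimal sketch, assuming tests/utils defines it roughly as follows (the
# exact definition and the CUDA branch are illustrative, not quoted):

from typing import Callable, Tuple

import numpy as np
import torch

TEST_NDARRAYS: Tuple[Callable, ...] = (np.array, torch.as_tensor)
if torch.cuda.is_available():
    # also exercise device handling when a GPU is present
    TEST_NDARRAYS = TEST_NDARRAYS + (lambda x: torch.as_tensor(x, device="cuda"),)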
-import torch -from tests.utils import TEST_NDARRAYS import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ResizeWithPadOrCropd +from tests.utils import TEST_NDARRAYS TEST_CASES = [ [ diff --git a/tests/test_spatial_crop.py b/tests/test_spatial_crop.py index c76915f0a3..50c5b74139 100644 --- a/tests/test_spatial_crop.py +++ b/tests/test_spatial_crop.py @@ -16,8 +16,9 @@ from parameterized import parameterized from monai.transforms import SpatialCrop +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ +TESTS = [ [ {"roi_center": [1, 1, 1], "roi_size": [2, 2, 2]}, (3, 3, 3, 3), @@ -53,17 +54,25 @@ class TestSpatialCrop(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_shape, expected_shape): + results = [] input_data = np.random.randint(0, 2, size=input_shape) - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) - - @parameterized.expand(TEST_CASES) - def test_tensor_shape(self, input_param, input_shape, expected_shape): - input_data = torch.randint(0, 2, size=input_shape, device="cuda" if torch.cuda.is_available() else "cpu") - result = SpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + input_param_mod = { + k: q(v) if k != "roi_slices" and q is not None else v for k, v in input_param.items() + } + im = p(input_data) + result = SpatialCrop(**input_param_mod)(im) + self.assertEqual(type(im), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im.device) + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand(TEST_ERRORS) def test_error(self, input_param): From 37da23efcd14a55c6d52eb73924e216cc408adfb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 22 Jul 2021 10:03:15 +0100 Subject: [PATCH 146/176] merge dev into Feature/2231 transforms (#2640) * 2630 adds look_up_option (#2631) * adds look_up_options Signed-off-by: Wenqi Li * adds docstring example Signed-off-by: Wenqi Li * fixes tests Signed-off-by: Wenqi Li * update enum look ups Signed-off-by: Wenqi Li * fixes typos Signed-off-by: Wenqi Li * update based on comments Signed-off-by: Wenqi Li * [DLMED] enhance distance metrics warning (#2638) Signed-off-by: Nic Ma * randweightedcrop (partially done) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> * fixes Pad type hint Signed-off-by: Wenqi Li Co-authored-by: Nic Ma Co-authored-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/grid_dataset.py | 4 +- monai/data/png_saver.py | 4 +- monai/data/png_writer.py | 4 +- monai/data/utils.py | 5 +- monai/engines/evaluator.py | 3 +- monai/handlers/utils.py | 5 +- monai/inferers/utils.py | 4 +- monai/losses/dice.py | 4 +- monai/metrics/hausdorff_distance.py | 5 + monai/metrics/rocauc.py | 4 +- monai/metrics/surface_distance.py | 5 + monai/metrics/utils.py | 11 +- monai/networks/blocks/upsample.py | 4 +- monai/networks/layers/factories.py | 6 +- monai/networks/layers/simplelayers.py | 5 +- monai/networks/layers/spatial_transforms.py | 6 +- monai/transforms/croppad/array.py | 21 ++-- monai/transforms/io/array.py | 7 +- monai/transforms/spatial/array.py | 66 ++++++------ monai/utils/__init__.py | 4 +- monai/utils/enums.py | 
27 ----- monai/utils/module.py | 114 +++++++++++++++++++- tests/test_dynunet_block.py | 2 +- tests/test_get_layers.py | 6 ++ tests/test_look_up_option.py | 71 ++++++++++++ tests/test_segresnet_block.py | 2 +- tests/test_unetr_block.py | 2 +- 27 files changed, 291 insertions(+), 110 deletions(-) create mode 100644 tests/test_look_up_option.py diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index c4efe3ad2a..5b2a4d7abd 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -18,7 +18,7 @@ from monai.data.dataset import Dataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple +from monai.utils import NumpyPadMode, ensure_tuple, look_up_option __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"] @@ -57,7 +57,7 @@ def __init__( """ self.patch_size = (None,) + tuple(patch_size) self.start_pos = ensure_tuple(start_pos) - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.pad_opts = pad_opts def __call__(self, array): diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 0d30b0132d..1ce787ba4e 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -17,7 +17,7 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode +from monai.utils import InterpolateMode, look_up_option from monai.utils.enums import DataObjects @@ -75,7 +75,7 @@ def __init__( self.output_postfix = output_postfix self.output_ext = output_ext self.resample = resample - self.mode: InterpolateMode = InterpolateMode(mode) + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.scale = scale self.data_root_dir = data_root_dir self.separate_folder = separate_folder diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 3ab921bf73..ccb988d877 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -15,7 +15,7 @@ import torch from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import +from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import from monai.utils.enums import DataObjects from monai.utils.misc import convert_data_type @@ -58,7 +58,7 @@ def write_png( data_np = data_np.squeeze(2) if output_spatial_shape is not None: output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2) - mode = InterpolateMode(mode) + mode = look_up_option(mode, InterpolateMode) align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners) _min, _max = np.min(data_np), np.max(data_np) diff --git a/monai/data/utils.py b/monai/data/utils.py index 32db1a7170..344006ac64 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -37,6 +37,7 @@ fall_back_tuple, first, issequenceiterable, + look_up_option, optional_import, ) from monai.utils.enums import DataObjects, Method @@ -217,7 +218,7 @@ def iter_patch( start_pos = ensure_tuple_size(start_pos, arr.ndim) # pad image by maximum values needed to ensure patches are taken from inside an image - arrpad = np.pad(arr, tuple((p, p) for p in patch_size_), NumpyPadMode(mode).value, **pad_opts) + arrpad = np.pad(arr, tuple((p, p) for p in patch_size_), look_up_option(mode, NumpyPadMode).value, 
**pad_opts) # choose a start position in the padded image start_pos_padded = tuple(s + p for s, p in zip(start_pos, patch_size_)) @@ -770,7 +771,7 @@ def compute_importance_map( Tensor of size patch_size. """ - mode = BlendMode(mode) + mode = look_up_option(mode, BlendMode) device = torch.device(device) # type: ignore[arg-type] if mode == BlendMode.CONSTANT: importance_map = torch.ones(patch_size, device=device).float() diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 2754eb3c61..1c37da71d4 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -22,6 +22,7 @@ from monai.transforms import Transform from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys +from monai.utils.module import look_up_option if TYPE_CHECKING: from ignite.engine import Engine, EventEnum @@ -109,7 +110,7 @@ def __init__( event_to_attr=event_to_attr, decollate=decollate, ) - mode = ForwardMode(mode) + self.mode = look_up_option(mode, ForwardMode) if mode == ForwardMode.EVAL: self.mode = eval_mode elif mode == ForwardMode.TRAIN: diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 5bed5dac54..e67c979ea4 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -17,7 +17,7 @@ import torch from monai.config import IgniteInfo, KeysCollection -from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, min_version, optional_import +from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, look_up_option, min_version, optional_import from monai.utils.enums import DataObjects idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -214,7 +214,8 @@ class mean median max 5percentile 95percentile notnans def _compute_op(op: str, d: np.ndarray): if not op.endswith("percentile"): - return supported_ops[op](d) + c_op = look_up_option(op, supported_ops) + return c_op(d) threshold = int(op.split("percentile")[0]) return supported_ops["90percentile"]((d, threshold)) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 85779fc6d1..0ca53529c7 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -15,7 +15,7 @@ import torch.nn.functional as F from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size -from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple +from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option __all__ = ["sliding_window_inference"] @@ -103,7 +103,7 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - inputs = F.pad(inputs, pad=pad_size, mode=PytorchPadMode(padding_mode).value, value=cval) + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 9cde049654..aa58490136 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -21,7 +21,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss from monai.networks import one_hot -from monai.utils import LossReduction, Weight +from monai.utils import LossReduction, Weight, look_up_option from monai.utils.enums import DataObjects @@ -267,7 +267,7 @@ def __init__( self.softmax = softmax self.other_act = other_act - self.w_type = 
Weight(w_type) + self.w_type = look_up_option(w_type, Weight) self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 36ea05b570..06118b87b4 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -158,6 +158,11 @@ def compute_hausdorff_distance( hd = np.empty((batch_size, n_class)) for b, c in np.ndindex(batch_size, n_class): (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") + distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) if directed: hd[b, c] = distance_1 diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 565ebbfc75..50aec9be12 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -14,7 +14,7 @@ import numpy as np import torch -from monai.utils import Average +from monai.utils import Average, look_up_option from .metric import CumulativeIterationMetric @@ -146,7 +146,7 @@ def compute_roc_auc( if y.shape != y_pred.shape: raise AssertionError("data shapes of y_pred and y do not match.") - average = Average(average) + average = look_up_option(average, Average) if average == Average.MICRO: return _calculate(y_pred.flatten(), y.flatten()) y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index 4cfbddac87..eacf63102b 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -153,6 +153,11 @@ def compute_average_surface_distance( for b, c in np.ndindex(batch_size, n_class): (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) + if not np.any(edges_gt): + warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.") + if not np.any(edges_pred): + warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.") + surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) if surface_distance.shape == (0,): avg_surface_distance = np.nan diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a6fcc50be1..313cc7f844 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import Tuple, Union import numpy as np @@ -17,7 +16,7 @@ from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import MetricReduction, optional_import +from monai.utils import MetricReduction, look_up_option, optional_import from monai.utils.enums import DataObjects binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") @@ -71,7 +70,7 @@ def do_metric_reduction( not_nans = (~nans).float() t_zero = torch.zeros(1, device=f.device, dtype=f.dtype) - reduction = MetricReduction(reduction) + reduction = look_up_option(reduction, MetricReduction) if reduction == MetricReduction.NONE: return f, not_nans @@ -189,15 +188,17 @@ def get_surface_distance( - ``"euclidean"``, uses Exact Euclidean distance transform. 
- ``"chessboard"``, uses `chessboard` metric in chamfer type of transform. - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform. + + Note: + If seg_pred or seg_gt is all 0, may result in nan/inf distance. + """ if not np.any(seg_gt): dis = np.inf * np.ones_like(seg_gt) - warnings.warn("ground truth is all 0, this may result in nan/inf distance.") else: if not np.any(seg_pred): dis = np.inf * np.ones_like(seg_gt) - warnings.warn("prediction is all 0, this may result in nan/inf distance.") return np.asarray(dis[seg_gt]) if distance_metric == "euclidean": dis = distance_transform_edt(~seg_gt) diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py index 84027c229e..f3c680f050 100644 --- a/monai/networks/blocks/upsample.py +++ b/monai/networks/blocks/upsample.py @@ -16,7 +16,7 @@ from monai.networks.layers.factories import Conv, Pad, Pool from monai.networks.utils import icnr_init, pixelshuffle -from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep +from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option __all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"] @@ -78,7 +78,7 @@ def __init__( """ super().__init__() scale_factor_ = ensure_tuple_rep(scale_factor, dimensions) - up_mode = UpsampleMode(mode) + up_mode = look_up_option(mode, UpsampleMode) if up_mode == UpsampleMode.DECONV: if not in_channels: raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.") diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py index f8aa4ebe22..d4de08fc50 100644 --- a/monai/networks/layers/factories.py +++ b/monai/networks/layers/factories.py @@ -64,6 +64,8 @@ def use_factory(fact_args): import torch.nn as nn +from monai.utils import look_up_option + __all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"] @@ -120,8 +122,8 @@ def get_constructor(self, factory_name: str, *args) -> Any: if not isinstance(factory_name, str): raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.") - fact = self.factories[factory_name.upper()] - return fact(*args) + func = look_up_option(factory_name.upper(), self.factories) + return func(*args) def __getitem__(self, args) -> Any: """ diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py index 6737e54da7..52f19aab29 100644 --- a/monai/networks/layers/simplelayers.py +++ b/monai/networks/layers/simplelayers.py @@ -25,6 +25,7 @@ InvalidPyTorchVersionError, SkipMode, ensure_tuple_rep, + look_up_option, optional_import, ) @@ -75,7 +76,7 @@ def __init__( self.pad = None if in_channels == out_channels: return - mode = ChannelMatching(mode) + mode = look_up_option(mode, ChannelMatching) if mode == ChannelMatching.PROJECT: conv_type = Conv[Conv.CONV, spatial_dims] self.project = conv_type(in_channels, out_channels, kernel_size=1) @@ -119,7 +120,7 @@ def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") super().__init__() self.submodule = submodule self.dim = dim - self.mode = SkipMode(mode).value + self.mode = look_up_option(mode, SkipMode).value def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.submodule(x) diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 03031f3340..511c24fcb0 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -15,7 +15,7 @@ import torch.nn as 
nn from monai.networks import to_norm_affine -from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, optional_import +from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, look_up_option, optional_import _C, _ = optional_import("monai._C") @@ -455,8 +455,8 @@ def __init__( super().__init__() self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None self.normalized = normalized - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.reverse_indexing = reverse_indexing diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 16daaf948a..9ae7de43cf 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -34,7 +34,7 @@ map_classes_to_indices, weighted_patch_samples, ) -from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option from monai.utils.enums import DataObjects from monai.utils.misc import convert_data_type @@ -78,11 +78,11 @@ class Pad(TorchTransform, NumpyTransform): def __init__( self, to_pad: List[Tuple[int, int]], - mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT, **np_kwargs, ) -> None: self.to_pad = to_pad - self.mode = mode + self.mode = mode or NumpyPadMode.CONSTANT self.np_kwargs = np_kwargs @staticmethod @@ -156,8 +156,8 @@ def __init__( **np_kwargs, ) -> None: self.spatial_size = spatial_size - self.method: Method = Method(method) - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.method: Method = look_up_option(method, Method) + self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: @@ -180,13 +180,13 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s One of the listed string values or a user supplied function. Defaults to ``self.mode``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ - mode = NumpyPadMode(mode or self.mode) data_pad_width = self._determine_data_pad_width(img.shape[1:]) all_pad_width = [(0, 0)] + data_pad_width if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img - padder = Pad(all_pad_width, mode or self.mode, **self.np_kwargs) + mode = look_up_option(mode or self.mode, NumpyPadMode) + padder = Pad(all_pad_width, mode, **self.np_kwargs) return padder(img) @@ -222,7 +222,7 @@ def __init__( **np_kwargs, ) -> None: self.spatial_border = spatial_border - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, str]] = None) -> DataObjects.Images: @@ -259,7 +259,8 @@ def __call__(self, img: DataObjects.Images, mode: Optional[Union[NumpyPadMode, s f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." 
) all_pad_width = [(0, 0)] + data_pad_width - padder = Pad(all_pad_width, mode or self.mode, **self.np_kwargs) + mode = look_up_option(mode or self.mode, NumpyPadMode) + padder = Pad(all_pad_width, mode, **self.np_kwargs) return padder(img) @@ -664,7 +665,7 @@ def __init__( self.margin = margin self.return_coords = return_coords self.k_divisible = k_divisible - self.mode: NumpyPadMode = NumpyPadMode(mode) + self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) def compute_bounding_box(self, img: DataObjects.Images) -> Tuple[np.ndarray, np.ndarray]: """ diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 203fea35ba..557baca470 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -29,6 +29,7 @@ from monai.utils import InterpolateMode, ensure_tuple, optional_import from monai.utils.enums import DataObjects from monai.utils.misc import convert_data_type +from monai.utils.module import look_up_option nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -112,10 +113,8 @@ def __init__( "itkreader": ITKReader, "numpyreader": NumpyReader, } - reader = reader.lower() - if reader not in supported_readers: - raise ValueError(f"unsupported reader type: {reader}, available options: {supported_readers}.") - self.register(supported_readers[reader](*args, **kwargs)) + the_reader = look_up_option(reader.lower(), supported_readers) + self.register(the_reader(*args, **kwargs)) else: self.register(reader) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2c5f073b53..e912182a13 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -48,6 +48,7 @@ ) from monai.utils.enums import DataObjects from monai.utils.misc import convert_data_type +from monai.utils.module import look_up_option nib, _ = optional_import("nibabel") @@ -127,8 +128,8 @@ def __init__( """ self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.diagonal = diagonal - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype @@ -201,13 +202,12 @@ def __call__( if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): output_data, *_ = convert_data_type(deepcopy(data_array), dtype=_dtype) new_affine = to_affine_nd(affine, new_affine) - else: # resample affine_xform = AffineTransform( normalized=False, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) @@ -371,7 +371,7 @@ def __init__( align_corners: Optional[bool] = None, ) -> None: self.spatial_size = ensure_tuple(spatial_size) - self.mode: InterpolateMode = InterpolateMode(mode) + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) self.align_corners = align_corners def __call__( @@ -410,7 +410,7 @@ def __call__( resized = torch.nn.functional.interpolate( input=img_t.unsqueeze(0), size=spatial_size, - mode=self.mode.value if mode is None else InterpolateMode(mode).value, + mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, 
align_corners=self.align_corners if align_corners is None else align_corners, ) resized = resized.squeeze(0) @@ -451,8 +451,8 @@ def __init__( ) -> None: self.angle = angle self.keep_size = keep_size - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype self._rotation_matrix: Optional[np.ndarray] = None @@ -512,8 +512,8 @@ def __call__( xform = AffineTransform( normalized=False, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), align_corners=self.align_corners if align_corners is None else align_corners, reverse_indexing=True, ) @@ -603,7 +603,7 @@ def __call__( recompute_scale_factor=True, input=img_t.float().unsqueeze(0), scale_factor=list(_zoom), - mode=self.mode.value if mode is None else InterpolateMode(mode).value, + mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) zoomed = zoomed.squeeze(0) @@ -620,7 +620,8 @@ def __call__( elif diff < 0: # need slicing slice_vec[idx] = slice(half, half + od) - padder = Pad(pad_vec, padding_mode or self.padding_mode) + padding_mode = look_up_option(padding_mode or self.padding_mode, NumpyPadMode) + padder = Pad(pad_vec, padding_mode) zoomed = padder(zoomed) zoomed = zoomed[tuple(slice_vec)] @@ -749,8 +750,8 @@ def __init__( self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]])) self.keep_size = keep_size - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype @@ -793,8 +794,8 @@ def __call__( rotator = Rotate( angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), keep_size=self.keep_size, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, # type: ignore ) @@ -902,8 +903,8 @@ def __init__( self.max_zoom = ensure_tuple(max_zoom) if len(self.min_zoom) != len(self.max_zoom): raise AssertionError("min_zoom and max_zoom must have same length.") - self.mode: InterpolateMode = InterpolateMode(mode) - self.padding_mode: NumpyPadMode = NumpyPadMode(padding_mode) + self.mode: InterpolateMode = look_up_option(mode, InterpolateMode) + self.padding_mode: NumpyPadMode = look_up_option(padding_mode, NumpyPadMode) self.align_corners = align_corners self.keep_size = keep_size @@ -949,8 +950,8 @@ def __call__( zoomer = Zoom( self._zoom, keep_size=self.keep_size, - mode=mode or self.mode, - padding_mode=padding_mode or self.padding_mode, + mode=look_up_option(mode or self.mode, InterpolateMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, NumpyPadMode), align_corners=align_corners or self.align_corners, ) return zoomer(img) @@ 
-1208,8 +1209,9 @@ def __init__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample device: device on which the tensor will be allocated. """ - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) + self.as_tensor_output = as_tensor_output self.device = device def __call__( @@ -1244,14 +1246,16 @@ def __call__( grid[i] += (dim - 1.0) / 2.0 grid = grid[:-1] / grid[-1:] grid = grid.permute(list(range(grid.ndimension()))[1:] + [0]) - _padding_mode = self.padding_mode.value if padding_mode is None else GridSamplePadMode(padding_mode).value + _padding_mode = look_up_option( + self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode + ).value if _padding_mode == "zeros": bound = 7 elif _padding_mode == "border": bound = 0 else: bound = 1 - _interp_mode = self.mode.value if mode is None else GridSampleMode(mode).value + _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value out = grid_pull( img_t.unsqueeze(0).float(), grid.unsqueeze(0).float(), @@ -1330,8 +1334,8 @@ def __init__( self.image_only = image_only self.resampler = Resample(device=device) self.spatial_size = spatial_size - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, @@ -1585,8 +1589,8 @@ def __init__( self.resampler = Resample(device=device) self.spatial_size = spatial_size - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None @@ -1702,8 +1706,8 @@ def __init__( self.sigma_range = sigma_range self.magnitude_range = magnitude_range self.spatial_size = spatial_size - self.mode: GridSampleMode = GridSampleMode(mode) - self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) + self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.device = device self.rand_offset: np.ndarray diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index f99db70deb..1c52e8dce6 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -15,7 +15,6 @@ from .deprecated import DeprecatedError, deprecated, deprecated_arg from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( - Activation, Average, BlendMode, ChannelMatching, @@ -29,7 +28,6 @@ LossReduction, Method, MetricReduction, - Normalization, NumpyPadMode, PytorchPadMode, SkipMode, @@ -65,12 +63,14 @@ PT_BEFORE_1_7, InvalidPyTorchVersionError, OptionalImportError, + damerau_levenshtein_distance, exact_version, export, get_full_type_name, get_package_version, get_torch_version_tuple, load_submodules, + look_up_option, min_version, optional_import, version_leq, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index f1d90b6063..398da78741 100644 --- 
a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -27,8 +27,6 @@
     "MetricReduction",
     "LossReduction",
     "Weight",
-    "Normalization",
-    "Activation",
     "ChannelMatching",
     "SkipMode",
     "Method",
@@ -177,31 +175,6 @@ class Weight(Enum):
     UNIFORM = "uniform"
 
 
-class Normalization(Enum):
-    """
-    See also:
-        - :py:class:`monai.networks.nets.ConvNormActi`
-        - :py:class:`monai.networks.nets.HighResBlock`
-        - :py:class:`monai.networks.nets.HighResNet`
-    """
-
-    BATCH = "batch"
-    INSTANCE = "instance"
-
-
-class Activation(Enum):
-    """
-    See also:
-        - :py:class:`monai.networks.nets.ConvNormActi`
-        - :py:class:`monai.networks.nets.HighResBlock`
-        - :py:class:`monai.networks.nets.HighResNet`
-    """
-
-    RELU = "relu"
-    PRELU = "prelu"
-    RELU6 = "relu6"
-
-
 class ChannelMatching(Enum):
     """
     See also: :py:class:`monai.networks.nets.HighResBlock`
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 2ccea2f05f..33314fb0e3 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -8,13 +8,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import sys
 import warnings
 from importlib import import_module
 from pkgutil import walk_packages
 from re import match
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast
 
 import torch
 
@@ -25,6 +25,8 @@
     "OptionalImportError",
     "exact_version",
     "export",
+    "damerau_levenshtein_distance",
+    "look_up_option",
     "min_version",
     "optional_import",
     "load_submodules",
@@ -36,6 +38,114 @@
 ]
 
 
+def look_up_option(opt_str, supported: Collection, default="no_default"):
+    """
+    Look up the option in the supported collection and return the matched item.
+    Raise a value error possibly with a guess of the closest match.
+
+    Args:
+        opt_str: The option string or Enum to look up.
+        supported: The collection of supported options, it can be list, tuple, set, dict, or Enum.
+        default: If it is given, this method will return `default` when `opt_str` is not found,
+            instead of raising a `ValueError`. Otherwise, it defaults to `"no_default"`,
+            so that the method may raise a `ValueError`.
+
+    Examples:
+
+    .. code-block:: python
+
+        from enum import Enum
+        from monai.utils import look_up_option
+        class Color(Enum):
+            RED = "red"
+            BLUE = "blue"
+        look_up_option("red", Color)  # <Color.RED: 'red'>
+        look_up_option(Color.RED, Color)  # <Color.RED: 'red'>
+        look_up_option("read", Color)
+        # ValueError: By 'read', did you mean 'red'?
+        # 'read' is not a valid option.
+        # Available options are {'blue', 'red'}.
+ look_up_option("red", {"red", "blue"}) # "red" + + Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/utilities/util_common.py#L249 + """ + if not isinstance(opt_str, Hashable): + raise ValueError(f"Unrecognized option type: {type(opt_str)}:{opt_str}.") + if isinstance(opt_str, str): + opt_str = opt_str.strip() + if isinstance(supported, enum.EnumMeta): + if isinstance(opt_str, str) and opt_str in {item.value for item in cast(Iterable[enum.Enum], supported)}: + # such as: "example" in MyEnum + return supported(opt_str) + if isinstance(opt_str, enum.Enum) and opt_str in supported: + # such as: MyEnum.EXAMPLE in MyEnum + return opt_str + elif isinstance(supported, Mapping) and opt_str in supported: + # such as: MyDict[key] + return supported[opt_str] + elif isinstance(supported, Collection) and opt_str in supported: + return opt_str + + if default != "no_default": + return default + + # find a close match + set_to_check: set + if isinstance(supported, enum.EnumMeta): + set_to_check = {item.value for item in cast(Iterable[enum.Enum], supported)} + else: + set_to_check = set(supported) if supported is not None else set() + if not set_to_check: + raise ValueError(f"No options available: {supported}.") + edit_dists = {} + opt_str = f"{opt_str}" + for key in set_to_check: + edit_dist = damerau_levenshtein_distance(f"{key}", opt_str) + if edit_dist <= 3: + edit_dists[key] = edit_dist + + supported_msg = f"Available options are {set_to_check}.\n" + if edit_dists: + guess_at_spelling = min(edit_dists, key=edit_dists.get) # type: ignore + raise ValueError( + f"By '{opt_str}', did you mean '{guess_at_spelling}'?\n" + + f"'{opt_str}' is not a valid option.\n" + + supported_msg + ) + raise ValueError(f"Unsupported option '{opt_str}', " + supported_msg) + + +def damerau_levenshtein_distance(s1: str, s2: str): + """ + Calculates the Damerau–Levenshtein distance between two strings for spelling correction. + https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance + """ + if s1 == s2: + return 0 + string_1_length = len(s1) + string_2_length = len(s2) + if not s1: + return string_2_length + if not s2: + return string_1_length + d = {(i, -1): i + 1 for i in range(-1, string_1_length + 1)} + for j in range(-1, string_2_length + 1): + d[(-1, j)] = j + 1 + + for i, s1i in enumerate(s1): + for j, s2j in enumerate(s2): + cost = 0 if s1i == s2j else 1 + d[(i, j)] = min( + d[(i - 1, j)] + 1, # deletion + d[(i, j - 1)] + 1, # insertion + d[(i - 1, j - 1)] + cost, # substitution + ) + if i and j and s1i == s2[j - 1] and s1[i - 1] == s2j: + d[(i, j)] = min(d[(i, j)], d[i - 2, j - 2] + cost) # transposition + + return d[string_1_length - 1, string_2_length - 1] + + def export(modname): """ Make the decorated object a member of the named module. 
This will also add the object under its aliases if it has diff --git a/tests/test_dynunet_block.py b/tests/test_dynunet_block.py index 97103a780c..7e832f6d81 100644 --- a/tests/test_dynunet_block.py +++ b/tests/test_dynunet_block.py @@ -76,7 +76,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): UnetBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name="norm") with self.assertRaises(AssertionError): UnetResBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") diff --git a/tests/test_get_layers.py b/tests/test_get_layers.py index 2a95a0e582..e6ea810a6b 100644 --- a/tests/test_get_layers.py +++ b/tests/test_get_layers.py @@ -51,5 +51,11 @@ def test_dropout_layer(self, input_param, expected): self.assertEqual(f"{layer}", expected) +class TestSuggestion(unittest.TestCase): + def test_suggested(self): + with self.assertRaisesRegex(ValueError, "did you mean 'GROUP'?"): + get_norm_layer(name="grop", spatial_dims=2) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py new file mode 100644 index 0000000000..60786f2fc5 --- /dev/null +++ b/tests/test_look_up_option.py @@ -0,0 +1,71 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
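# Before the unit tests, a short usage sketch of look_up_option; the behaviour
# shown follows the implementation and docstring above, while the Mode enum
# here is purely illustrative:

from enum import Enum

from monai.utils import look_up_option


class Mode(Enum):
    NEAREST = "nearest"
    AREA = "area"


look_up_option("nearest", Mode)                # Mode.NEAREST
look_up_option(Mode.AREA, Mode)                # Mode.AREA
look_up_option("missing", Mode, default=None)  # returns None instead of raising
try:
    look_up_option("nearset", Mode)  # Damerau-Levenshtein distance 1 from "nearest"
except ValueError as err:
    print(err)  # By 'nearset', did you mean 'nearest'? ...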
+ +import unittest +from enum import Enum + +from parameterized import parameterized + +from monai.utils import look_up_option + + +class _CaseEnum(Enum): + CONST = "constant" + EMPTY = "empty" + + +class _CaseEnum1(Enum): + CONST = "constant" + EMPTY = "empty" + + +TEST_CASES = ( + ("test", ("test", "test1"), "test"), + ("test1", {"test1", "test"}, "test1"), + (2, {1: "test", 2: "valid"}, "valid"), + (_CaseEnum.EMPTY, _CaseEnum, _CaseEnum.EMPTY), + ("empty", _CaseEnum, _CaseEnum.EMPTY), +) + + +class TestLookUpOption(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_look_up(self, input_str, supported, expected): + output = look_up_option(input_str, supported) + self.assertEqual(output, expected) + + def test_default(self): + output = look_up_option("not here", {"a", "b"}, default=None) + self.assertEqual(output, None) + + def test_no_found(self): + with self.assertRaisesRegex(ValueError, "Unsupported"): + look_up_option("not here", {"a", "b"}) + with self.assertRaisesRegex(ValueError, "Unsupported"): + look_up_option("not here", ["a", "b"]) + with self.assertRaisesRegex(ValueError, "Unsupported"): + look_up_option("not here", {"a": 1, "b": 2}) + with self.assertRaisesRegex(ValueError, "did you mean"): + look_up_option(3, {1: "a", 2: "b", "c": 3}) + with self.assertRaisesRegex(ValueError, "did.*empty"): + look_up_option("empy", _CaseEnum) + with self.assertRaisesRegex(ValueError, "Unsupported"): + look_up_option(_CaseEnum1.EMPTY, _CaseEnum) + with self.assertRaisesRegex(ValueError, "Unsupported"): + look_up_option(None, _CaseEnum) + with self.assertRaisesRegex(ValueError, "No"): + look_up_option(None, None) + with self.assertRaisesRegex(ValueError, "No"): + look_up_option("test", None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_segresnet_block.py b/tests/test_segresnet_block.py index bae7bffd2c..eb8cc9676b 100644 --- a/tests/test_segresnet_block.py +++ b/tests/test_segresnet_block.py @@ -46,7 +46,7 @@ def test_shape(self, input_param, input_shape, expected_shape): def test_ill_arg(self): with self.assertRaises(AssertionError): ResBlock(spatial_dims=3, in_channels=8, norm="group", kernel_size=2) - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): ResBlock(spatial_dims=3, in_channels=8, norm="norm") diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py index ba988d07e6..0b22838fae 100644 --- a/tests/test_unetr_block.py +++ b/tests/test_unetr_block.py @@ -111,7 +111,7 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): UnetrBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name="norm") with self.assertRaises(AssertionError): UnetrBasicBlock(3, 4, 2, kernel_size=1, stride=4, norm_name="batch") From 60f35fce099aa0cdd6beb671365ef833e9b505c6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 22 Jul 2021 13:24:56 +0000 Subject: [PATCH 147/176] RandSpatialCrop, RandSpatialCropSamples Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 14 +-- monai/transforms/spatial/array.py | 1 - monai/utils/misc.py | 4 + tests/test_rand_scale_crop.py | 45 ++++++-- tests/test_rand_scale_cropd.py | 55 +++++++-- tests/test_rand_spatial_crop.py | 51 +++++++-- tests/test_rand_spatial_crop_samples.py | 108 ++++++++++-------- tests/test_rand_spatial_crop_samplesd.py | 138 
+++++++++++++---------- 8 files changed, 276 insertions(+), 140 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 9ae7de43cf..cfe6125089 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -23,7 +23,7 @@ from monai.config import IndexSelection from monai.data.utils import get_random_patch, get_valid_patch_size -from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform +from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform from monai.transforms.utils import ( compute_divisible_spatial_size, generate_label_classes_crop_centers, @@ -432,7 +432,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return sp_crop(img=img) -class RandSpatialCrop(Randomizable, Transform): +class RandSpatialCrop(Randomizable, TorchTransform, NumpyTransform): """ Crop image with random size or specific size ROI. It can crop at a random position as center or at the image center. And allows to set the minimum and maximum size to limit the randomly generated ROI. @@ -481,7 +481,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, self._size) self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -527,7 +527,7 @@ def __init__( self.roi_scale = roi_scale self.max_roi_scale = max_roi_scale - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. @@ -542,7 +542,7 @@ def __call__(self, img: np.ndarray): return super().__call__(img=img) -class RandSpatialCropSamples(Randomizable, Transform): +class RandSpatialCropSamples(Randomizable, TorchTransform, NumpyTransform): """ Crop image with random size or specific size ROI to generate a list of N samples. It can crop at a random position as center or at the image center. And allows to set @@ -593,10 +593,10 @@ def set_random_state( self.cropper.set_random_state(state=self.R) return self - def randomize(self, data: Optional[Any] = None) -> None: + def randomize(self, _: Optional[Any] = None) -> None: pass - def __call__(self, img: np.ndarray) -> List[np.ndarray]: + def __call__(self, img: DataObjects.Images) -> List[DataObjects.Images]: """ Apply the transform to `img`, assuming `img` is channel-first and cropping doesn't change the channel dim. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e912182a13..d5a2873b6b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1211,7 +1211,6 @@ def __init__( """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.as_tensor_output = as_tensor_output self.device = device def __call__( diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3d8ba2a4e3..5dcf71a0eb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -129,6 +129,10 @@ def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: ValueError: Sequence must have length 3, got length 2. 
""" + if isinstance(tup, torch.Tensor): + tup = tup.cpu().numpy() + if isinstance(tup, np.ndarray): + tup = tup.tolist() if not issequenceiterable(tup): return (tup,) * dim if len(tup) == dim: diff --git a/tests/test_rand_scale_crop.py b/tests/test_rand_scale_crop.py index db5487ebff..a0870a12b5 100644 --- a/tests/test_rand_scale_crop.py +++ b/tests/test_rand_scale_crop.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandScaleCrop +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -55,22 +57,45 @@ class TestRandScaleCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + im = p(input_data) + cropper = RandScaleCrop(**input_param) + cropper.set_random_state(123) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index 265c6c467d..0f04bab484 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandScaleCropd +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"keys": "img", "roi_scale": [1.0, 1.0, -1.0], "random_center": True}, @@ -61,22 +63,55 @@ class TestRandScaleCropd(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandScaleCropd(**input_param)(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = 
RandScaleCropd(**input_param) + cropper.set_random_state(0) + result = cropper(data)["img"] + self.assertTupleEqual(result.shape, expected_shape) + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def test_value(self, input_param, input_data): - cropper = RandScaleCropd(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result["img"], input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = RandScaleCropd(**input_param) + cropper.set_random_state(0) + result = cropper(data)["img"] + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + np.testing.assert_allclose(result, input_data["img"][:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandScaleCropd(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result["img"].shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + data = {"img": p(input_data["img"])} + cropper = RandScaleCropd(**input_param) + cropper.set_random_state(seed=123) + result = cropper(data)["img"] + self.assertTupleEqual(result.shape, expected_shape) + self.assertEqual(type(data["img"]), type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["img"].device) + result = result.cpu().numpy() + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop.py b/tests/test_rand_spatial_crop.py index 01e057e589..2897ce59f5 100644 --- a/tests/test_rand_spatial_crop.py +++ b/tests/test_rand_spatial_crop.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandSpatialCrop +from tests.utils import TEST_NDARRAYS TEST_CASE_0 = [ {"roi_size": [3, 3, -1], "random_center": True}, @@ -51,22 +53,51 @@ class TestRandSpatialCrop(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) def test_shape(self, input_param, input_data, expected_shape): - result = RandSpatialCrop(**input_param)(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) @parameterized.expand([TEST_CASE_3]) def 
test_value(self, input_param, input_data): - cropper = RandSpatialCrop(**input_param) - result = cropper(input_data) - roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] - np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(0) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size] + np.testing.assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]]) @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) def test_random_shape(self, input_param, input_data, expected_shape): - cropper = RandSpatialCrop(**input_param) - cropper.set_random_state(seed=123) - result = cropper(input_data) - self.assertTupleEqual(result.shape, expected_shape) + results = [] + for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS + (None,): + im = p(input_data) + input_param_mod = {k: q(v) if q is not None else v for k, v in input_param.items()} + cropper = RandSpatialCrop(**input_param_mod) + cropper.set_random_state(123) + result = cropper(im) + if isinstance(result, torch.Tensor): + result = result.cpu().numpy() + self.assertTupleEqual(result.shape, expected_shape) + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samples.py b/tests/test_rand_spatial_crop_samples.py index 0ade9bbbba..d84247b4da 100644 --- a/tests/test_rand_spatial_crop_samples.py +++ b/tests/test_rand_spatial_crop_samples.py @@ -12,63 +12,80 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandSpatialCropSamples +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, - np.arange(192).reshape(3, 4, 4, 4), - [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( [ - [ - [[21, 22, 23], [25, 26, 27], [29, 30, 31]], - [[37, 38, 39], [41, 42, 43], [45, 46, 47]], - [[53, 54, 55], [57, 58, 59], [61, 62, 63]], - ], - [ - [[85, 86, 87], [89, 90, 91], [93, 94, 95]], - [[101, 102, 103], [105, 106, 107], [109, 110, 111]], - [[117, 118, 119], [121, 122, 123], [125, 126, 127]], - ], - [ - [[149, 150, 151], [153, 154, 155], [157, 158, 159]], - [[165, 166, 167], [169, 170, 171], [173, 174, 175]], - [[181, 182, 183], [185, 186, 187], [189, 190, 191]], - ], + {"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False}, + p(np.arange(192).reshape(3, 4, 4, 4)), + [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], + np.array( + [ + [ + [[21, 22, 23], [25, 26, 27], [29, 30, 31]], + [[37, 38, 39], [41, 42, 43], [45, 46, 47]], + [[53, 54, 55], [57, 58, 59], [61, 62, 63]], + ], + [ + [[85, 86, 87], [89, 90, 91], [93, 94, 95]], + [[101, 102, 103], [105, 106, 107], [109, 110, 111]], + [[117, 118, 119], [121, 122, 123], [125, 126, 127]], + ], + [ + [[149, 150, 151], [153, 154, 155], [157, 158, 159]], + [[165, 166, 167], [169, 170, 171], [173, 174, 175]], + [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + ], + ] + ), ] - ), -] + ) -TEST_CASE_2 = [ - {"roi_size": [3, 3, 3], 
"num_samples": 8, "random_center": False, "random_size": True}, - np.arange(192).reshape(3, 4, 4, 4), - [(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)], - np.array( + TESTS.append( [ + {"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True}, + p(np.arange(192).reshape(3, 4, 4, 4)), [ - [[21, 22, 23], [25, 26, 27], [29, 30, 31]], - [[37, 38, 39], [41, 42, 43], [45, 46, 47]], - [[53, 54, 55], [57, 58, 59], [61, 62, 63]], - ], - [ - [[85, 86, 87], [89, 90, 91], [93, 94, 95]], - [[101, 102, 103], [105, 106, 107], [109, 110, 111]], - [[117, 118, 119], [121, 122, 123], [125, 126, 127]], - ], - [ - [[149, 150, 151], [153, 154, 155], [157, 158, 159]], - [[165, 166, 167], [169, 170, 171], [173, 174, 175]], - [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + (3, 4, 4, 3), + (3, 4, 3, 3), + (3, 3, 4, 4), + (3, 4, 4, 4), + (3, 3, 3, 4), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), ], + np.array( + [ + [ + [[21, 22, 23], [25, 26, 27], [29, 30, 31]], + [[37, 38, 39], [41, 42, 43], [45, 46, 47]], + [[53, 54, 55], [57, 58, 59], [61, 62, 63]], + ], + [ + [[85, 86, 87], [89, 90, 91], [93, 94, 95]], + [[101, 102, 103], [105, 106, 107], [109, 110, 111]], + [[117, 118, 119], [121, 122, 123], [125, 126, 127]], + ], + [ + [[149, 150, 151], [153, 154, 155], [157, 158, 159]], + [[165, 166, 167], [169, 170, 171], [173, 174, 175]], + [[181, 182, 183], [185, 186, 187], [189, 190, 191]], + ], + ] + ), ] - ), -] + ) class TestRandSpatialCropSamples(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_shape, expected_last_item): xform = RandSpatialCropSamples(**input_param) xform.set_random_state(1234) @@ -77,7 +94,10 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last_item np.testing.assert_equal(len(result), input_param["num_samples"]) for item, expected in zip(result, expected_shape): self.assertTupleEqual(item.shape, expected) - np.testing.assert_allclose(result[-1], expected_last_item) + r = result[-1] + if isinstance(r, torch.Tensor): + r = r.cpu() + np.testing.assert_allclose(r, expected_last_item) if __name__ == "__main__": diff --git a/tests/test_rand_spatial_crop_samplesd.py b/tests/test_rand_spatial_crop_samplesd.py index 3f5eee7b27..47979838a4 100644 --- a/tests/test_rand_spatial_crop_samplesd.py +++ b/tests/test_rand_spatial_crop_samplesd.py @@ -12,58 +12,74 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Compose, RandSpatialCropSamplesd, ToTensord +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], - { - "img": np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, + [(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)], + { + "img": np.array( + [ + [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], + [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], + [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], + ] 
+ ), + "seg": np.array( + [ + [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]], + [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], + [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], + ] + ), + }, + ] + ) + TESTS.append( + [ + {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, + {"img": p(np.arange(81).reshape(3, 3, 3, 3)), "seg": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))}, [ - [[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]], - [[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]], - [[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]], - ] - ), - "seg": np.array( - [ - [[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]], - [[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]], - [[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]], - ] - ), - }, -] - -TEST_CASE_2 = [ - {"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False}, - {"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)}, - [(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)], - { - "img": np.array( - [ - [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], - [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], - [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], - ] - ), - "seg": np.array( - [ - [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], - [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], - [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], - ] - ), - }, -] + (3, 3, 3, 3), + (3, 2, 3, 3), + (3, 2, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 3, 3, 3), + (3, 2, 2, 3), + (3, 3, 2, 3), + ], + { + "img": np.array( + [ + [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]], + [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]], + [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]], + ] + ), + "seg": np.array( + [ + [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]], + [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]], + [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]], + ] + ), + }, + ] + ) class TestRandSpatialCropSamplesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_shape(self, input_param, input_data, expected_shape, expected_last): + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data, expected_shape, expected_last): xform = RandSpatialCropSamplesd(**input_param) xform.set_random_state(1234) result = xform(input_data) @@ -73,24 +89,30 @@ def test_shape(self, input_param, input_data, expected_shape, expected_last): for i, item in enumerate(result): self.assertEqual(item["img_meta_dict"]["patch_index"], i) self.assertEqual(item["seg_meta_dict"]["patch_index"], i) + self.assertEqual(type(item["img"]), type(input_data["img"])) + if isinstance(item["img"], torch.Tensor): + self.assertEqual(item["img"].device, input_data["img"].device) + item["img"] = item["img"].cpu() + item["seg"] = item["seg"].cpu() 
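+
+            # with both keys back on CPU, the same numpy comparisons below
+            # apply unchanged to numpy inputs and to (possibly GPU) tensors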
np.testing.assert_allclose(item["img"], expected_last["img"]) np.testing.assert_allclose(item["seg"], expected_last["seg"]) def test_deep_copy(self): - data = {"img": np.ones((1, 10, 11, 12))} - num_samples = 3 - sampler = RandSpatialCropSamplesd( - keys=["img"], - roi_size=(3, 3, 3), - num_samples=num_samples, - random_center=True, - random_size=False, - ) - transform = Compose([ToTensord(keys="img"), sampler]) - samples = transform(data) - self.assertEqual(len(samples), num_samples) - for sample in samples: - self.assertEqual(len(sample["img_transforms"]), len(transform)) + for p in TEST_NDARRAYS: + data = {"img": p(np.ones((1, 10, 11, 12)))} + num_samples = 3 + sampler = RandSpatialCropSamplesd( + keys=["img"], + roi_size=(3, 3, 3), + num_samples=num_samples, + random_center=True, + random_size=False, + ) + transform = Compose([ToTensord(keys="img"), sampler]) + samples = transform(data) + self.assertEqual(len(samples), num_samples) + for sample in samples: + self.assertEqual(len(sample["img_transforms"]), len(transform)) if __name__ == "__main__": From dcd226b7e2c6e605d65eac730f1163c63573c409 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 22 Jul 2021 15:12:16 +0000 Subject: [PATCH 148/176] k space spike Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 99 ++++++++++-------------- monai/transforms/intensity/dictionary.py | 28 ++----- tests/test_k_space_spike_noise.py | 40 ++++++---- tests/test_k_space_spike_noised.py | 54 ++++++++----- tests/test_rand_k_space_spike_noise.py | 59 ++++++++------ tests/test_rand_k_space_spike_noised.py | 61 +++++++++------ 6 files changed, 179 insertions(+), 162 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a9aa1ba978..81a25c7123 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -24,7 +24,7 @@ from monai.config import DtypeLike from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter -from monai.transforms.transform import NumpyTransform, RandomizableTransform, TorchTransform, Transform +from monai.transforms.transform import NumpyTransform, RandomizableTransform, TorchTransform from monai.transforms.utils import rescale_array from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, ensure_tuple, ensure_tuple_rep, ensure_tuple_size from monai.utils.enums import DataObjects @@ -1362,7 +1362,7 @@ def _apply_mask(self, k: DataObjects.Images, data_type: type) -> DataObjects.Ima return out -class KSpaceSpikeNoise(Transform): +class KSpaceSpikeNoise(TorchTransform, NumpyTransform): """ Apply localized spikes in `k`-space at the given locations and intensities. Spike (Herringbone) artifact is a type of data acquisition artifact which @@ -1389,8 +1389,6 @@ class KSpaceSpikeNoise(Transform): receive a sequence of intensities. This value should be tested as it is data-dependent. The default values are the 2.5 the mean of the log-intensity for each channel. - as_tensor_output: if ``True`` return torch.Tensor, else return np.array. - Default: ``True``. 
Example:
        When working with 4D data,
        ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))``
@@ -1403,11 +1401,9 @@
     def __init__(
         self,
         loc: Union[Tuple, Sequence[Tuple]],
         k_intensity: Optional[Union[Sequence[float], float]] = None,
-        as_tensor_output: bool = True,
     ):

         self.loc = ensure_tuple(loc)
-        self.as_tensor_output = as_tensor_output
         self.k_intensity = k_intensity

         # assert one-to-one relationship between factors and locations
@@ -1422,7 +1418,7 @@ def __init__(
             if not isinstance(self.k_intensity, Sequence):
                 raise AssertionError("There must be one intensity_factor value for each tuple of indices in loc.")

-    def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
+    def __call__(self, img: DataObjects.Images) -> DataObjects.Images:
         """
         Args:
            img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D)
@@ -1438,23 +1434,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
             raise AssertionError("Input images of dimension 4 need location tuple to be length 3 or 4")

         n_dims = len(img.shape[1:])
-
-        # convert to ndarray to work with np.fft
-        if isinstance(img, torch.Tensor):
-            device = img.device
-            img = img.cpu().detach().numpy()
-        else:
-            device = torch.device("cpu")
+        data_type = type(img)
+        lib = torch if data_type is torch.Tensor else np  # array module matching the input type

         # FT
-        k = self._shift_fourier(img, n_dims)
-        log_abs = np.log(np.absolute(k) + 1e-10)
-        phase = np.angle(k)
+        k = self._shift_fourier(img, n_dims, data_type)
+        log_abs = lib.log(lib.absolute(k) + 1e-10)  # type: ignore
+        phase = lib.angle(k)  # type: ignore

         k_intensity = self.k_intensity
         # default log intensity
         if k_intensity is None:
-            k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5)
+            k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5)  # type: ignore

         # highlight
         if isinstance(self.loc[0], Sequence):
@@ -1463,9 +1453,10 @@
         else:
             self._set_spike(log_abs, self.loc, k_intensity)
         # map back
-        k = np.exp(log_abs) * np.exp(1j * phase)
-        img = self._inv_shift_fourier(k, n_dims)
-        return torch.Tensor(img, device=device) if self.as_tensor_output else img
+        k = lib.exp(log_abs) * lib.exp(1j * phase)  # type: ignore
+        img = self._inv_shift_fourier(k, n_dims, data_type)
+
+        return img

     def _check_indices(self, img) -> None:
         """Helper method to check consistency of self.loc and input image.
@@ -1485,7 +1476,7 @@ def _check_indices(self, img) -> None:
                     f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image."
                 )

-    def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]):
+    def _set_spike(self, k: DataObjects.Images, idx: Tuple, val: Union[Sequence[float], float]):
         """
         Helper function to introduce a given intensity at given location.

@@ -1504,7 +1495,7 @@ def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], floa
         elif len(k.shape) == 3 and len(idx) == 2:
             k[:, idx[0], idx[1]] = val

-    def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
+    def _shift_fourier(self, x: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images:
         """
         Applies fourier transform and shifts its output.
         Only the spatial dimensions get transformed.

         Args:
-            x (np.ndarray): tensor to fourier transform.
+            x: tensor to fourier transform.
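+            n_dims: number of trailing spatial dimensions to transform.
+            data_type: array type of ``x`` (``torch.Tensor`` or ``np.ndarray``),
+                used to select the matching fft implementation.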
""" - out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) + # argument is dim if torch, else axes + mod, arg = (torch, "dim") if data_type is torch.Tensor else (np, "axes") + arg_dict = {arg: tuple(range(-n_dims, 0))} + out: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict) # type: ignore return out - def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray: + def _inv_shift_fourier(self, k: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ Applies inverse shift and fourier transform. Only the spatial dimensions are transformed. """ - out: np.ndarray = np.fft.ifftn( - np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)) - ).real - return out + dims = tuple(range(-n_dims, 0)) + out: DataObjects.Images + if data_type is torch.Tensor: + out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward") + else: + out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims) + return out.real -class RandKSpaceSpikeNoise(RandomizableTransform): +class RandKSpaceSpikeNoise(RandomizableTransform, TorchTransform, NumpyTransform): """ Naturalistic data augmentation via spike artifacts. The transform applies localized spikes in `k`-space, and it is the random version of @@ -1555,8 +1552,6 @@ class RandKSpaceSpikeNoise(RandomizableTransform): log-intensity for each channel. channel_wise: treat each channel independently. True by default. - as_tensor_output: if True return torch.Tensor, else - return np.array. default: True. Example: To apply `k`-space spikes randomly with probability 0.5, and @@ -1570,12 +1565,10 @@ def __init__( prob: float = 0.1, intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None, channel_wise=True, - as_tensor_output: bool = True, ): self.intensity_range = intensity_range self.channel_wise = channel_wise - self.as_tensor_output = as_tensor_output self.sampled_k_intensity: List[float] = [] self.sampled_locs: List[Tuple] = [] @@ -1587,7 +1580,7 @@ def __init__( super().__init__(prob) - def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply transform to `img`. Assumes data is in channel-first form. @@ -1603,19 +1596,16 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, self.sampled_k_intensity = [] self.sampled_locs = [] - # convert to ndarray to work with np.fft - x, device = self._to_numpy(img) - intensity_range = self._make_sequence(x) - self._randomize(x, intensity_range) + intensity_range = self._make_sequence(img) + self._randomize(img, intensity_range) # build/apply transform only if there are spike locations if self.sampled_locs: - transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output) - return transform(x) - - return torch.Tensor(x, device=device) if self.as_tensor_output else x + transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity) + return transform(img) + return img - def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None: + def _randomize(self, img: DataObjects.Images, intensity_range: Sequence[Sequence[float]]) -> None: """ Helper method to sample both the location and intensity of the spikes. 
When not working channel wise (channel_wise=False) it uses the random
@@ -1643,7 +1633,7 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]
         else:
             self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img)  # type: ignore

-    def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
+    def _make_sequence(self, x: DataObjects.Images) -> Sequence[Sequence[float]]:
         """
         Formats the sequence of intensities ranges to Sequence[Sequence[float]].
         """
@@ -1657,23 +1647,20 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
             # set default range if one not provided
             return self._set_default_range(x)

-    def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]:
+    def _set_default_range(self, x: DataObjects.Images) -> Sequence[Sequence[float]]:
         """
         Sets default intensity ranges to be sampled.

         Args:
-            x (np.ndarray): tensor to fourier transform.
+            x: tensor to fourier transform.
         """
         n_dims = len(x.shape[1:])

-        k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
-        log_abs = np.log(np.absolute(k) + 1e-10)
-        shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5
+        mod, arg = (torch, "dim") if type(x) is torch.Tensor else (np, "axes")
+        arg_dict = {arg: tuple(range(-n_dims, 0))}
+        k: DataObjects.Images = mod.fft.fftshift(mod.fft.fftn(x, **arg_dict), **arg_dict)  # type: ignore
+
+        log_abs = mod.log(mod.absolute(k) + 1e-10)  # type: ignore
+        shifted_means = mod.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5  # type: ignore
         intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means)
         return intensity_sequence
-
-    def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.device]:
-        if isinstance(img, torch.Tensor):
-            return img.cpu().detach().numpy(), img.device
-        else:
-            return img, torch.device("cpu")
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index d38fa8d857..5106ab0350 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -16,7 +16,7 @@
 """

 from collections.abc import Iterable
-from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -1162,8 +1162,6 @@ class KSpaceSpikeNoised(MapTransform):
             receive a sequence of intensities. This value should be tested as it is
             data-dependent. The default values are 2.5 times the mean of the
             log-intensity for each channel.
-        as_tensor_output: if ``True`` return torch.Tensor, else return np.array.
-            Default: ``True``.
         allow_missing_keys: do not raise exception if key is missing.
Example: @@ -1179,16 +1177,13 @@ def __init__( keys: KeysCollection, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.transform = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output) + self.transform = KSpaceSpikeNoise(loc, k_intensity) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1237,8 +1232,6 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform): common_sampling: If ``True`` same values for location and log-intensity will be sampled for the image and label. common_seed: Seed to be used in case ``common_sampling = True``. - as_tensor_output: if ``True`` return torch.Tensor, else return - np.array. Default: ``True``. allow_missing_keys: do not raise exception if key is missing. Example: @@ -1258,7 +1251,6 @@ def __init__( channel_wise: bool = True, common_sampling: bool = False, common_seed: int = 42, - as_tensor_output: bool = True, allow_missing_keys: bool = False, ): @@ -1267,14 +1259,11 @@ def __init__( self.common_sampling = common_sampling self.common_seed = common_seed - self.as_tensor_output = as_tensor_output # the spikes artifact is amplitude dependent so we instantiate one per key - self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise, self.as_tensor_output) - self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise, self.as_tensor_output) + self.t_img = RandKSpaceSpikeNoise(prob, img_intensity_range, channel_wise) + self.t_label = RandKSpaceSpikeNoise(prob, label_intensity_range, channel_wise) - def __call__( - self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]] - ) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: """ Args: data: Expects image/label to have dimensions (C, H, W) or @@ -1292,11 +1281,6 @@ def __call__( if self._do_transform: transform = self.t_img if key == "image" else self.t_label d[key] = transform(d[key]) - else: - if isinstance(d[key], np.ndarray) and self.as_tensor_output: - d[key] = torch.Tensor(d[key]) - elif isinstance(d[key], torch.Tensor) and not self.as_tensor_output: - d[key] = self._to_numpy(d[key]) return d def set_rand_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> None: diff --git a/tests/test_k_space_spike_noise.py b/tests/test_k_space_spike_noise.py index 53661d5fcb..66763f286f 100644 --- a/tests/test_k_space_spike_noise.py +++ b/tests/test_k_space_spike_noise.py @@ -20,12 +20,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) class TestKSpaceSpikeNoise(unittest.TestCase): @@ -37,34 +37,44 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = 
create_test_image_2d if len(im_shape) == 2 else create_test_image_3d
-        im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]
-        return torch.Tensor(im) if as_tensor_input else im
+        im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)
+        return im_type(im[None])

-    @parameterized.expand(TEST_CASES)
-    def test_same_result(self, im_shape, as_tensor_output, as_tensor_input):
+    @parameterized.expand(TESTS)
+    def test_same_result(self, im_shape, im_type):

-        im = self.get_data(im_shape, as_tensor_input)
+        im = self.get_data(im_shape, im_type)
        loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]
        k_intensity = 10
-        t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output)
+        t = KSpaceSpikeNoise(loc, k_intensity)

        out1 = t(deepcopy(im))
        out2 = t(deepcopy(im))
+        self.assertEqual(type(im), type(out1))
+        if isinstance(out1, torch.Tensor):
+            self.assertEqual(im.device, out1.device)
+            out1 = out1.cpu()
+            out2 = out2.cpu()
+
        np.testing.assert_allclose(out1, out2)
-        self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray)

-    @parameterized.expand(TEST_CASES)
-    def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input):
+    @parameterized.expand(TESTS)
+    def test_highlighted_kspace_pixel(self, im_shape, im_type):

-        im = self.get_data(im_shape, as_tensor_input)
+        im = self.get_data(im_shape, im_type)
        loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]
        k_intensity = 10
-        t = KSpaceSpikeNoise(loc, k_intensity, as_tensor_output)
+        t = KSpaceSpikeNoise(loc, k_intensity)
        out = t(im)

+        self.assertEqual(type(im), type(out))
+        if isinstance(out, torch.Tensor):
+            self.assertEqual(im.device, out.device)
+            out = out.cpu()
+
        n_dims = len(im_shape)
        out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
        log_mag = np.log(np.absolute(out_k))
diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py
index e5d2dfb6f8..3fa6a394f3 100644
--- a/tests/test_k_space_spike_noised.py
+++ b/tests/test_k_space_spike_noised.py
@@ -20,12 +20,12 @@
 from monai.data.synthetic import create_test_image_2d, create_test_image_3d
 from monai.transforms import KSpaceSpikeNoised
 from monai.utils.misc import set_determinism
+from tests.utils import TEST_NDARRAYS

-TEST_CASES = []
+TESTS = []
 for shape in ((128, 64), (64, 48, 80)):
-    for as_tensor_output in (True, False):
-        for as_tensor_input in (True, False):
-            TEST_CASES.append((shape, as_tensor_output, as_tensor_input))
+    for p in TEST_NDARRAYS:
+        TESTS.append((shape, p))

 KEYS = ["image", "label"]

@@ -39,55 +39,69 @@ def tearDown(self):
         set_determinism(None)

     @staticmethod
-    def get_data(im_shape, as_tensor_input):
+    def get_data(im_shape, im_type):
         create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d
         ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)
-        ims = [im[None] for im in ims]
-        ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims
+        ims = [im_type(im[None]) for im in ims]
         return {k: v for k, v in zip(KEYS, ims)}

-    @parameterized.expand(TEST_CASES)
-    def test_same_result(self, im_shape, as_tensor_output, as_tensor_input):
+    @parameterized.expand(TESTS)
+    def test_same_result(self, im_shape, im_type):

-        data = self.get_data(im_shape, as_tensor_input)
+        data = self.get_data(im_shape, im_type)
         loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))]
         k_intensity = 10
-        t = KSpaceSpikeNoised(KEYS, 
loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out1 = t(deepcopy(data)) out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k]) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_highlighted_kspace_pixel(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_highlighted_kspace_pixel(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(data) for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + n_dims = len(im_shape) out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))) log_mag = np.log(np.absolute(out_k)) np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1) - @parameterized.expand(TEST_CASES) - def test_dict_matches(self, im_shape, _, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_dict_matches(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))] k_intensity = 10 - t = KSpaceSpikeNoised(KEYS, loc, k_intensity, as_tensor_output) + t = KSpaceSpikeNoised(KEYS, loc, k_intensity) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) diff --git a/tests/test_rand_k_space_spike_noise.py b/tests/test_rand_k_space_spike_noise.py index ba9156c5b2..1c9ca9c1d5 100644 --- a/tests/test_rand_k_space_spike_noise.py +++ b/tests/test_rand_k_space_spike_noise.py @@ -19,13 +19,13 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - for channel_wise in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input, channel_wise)) + for p in TEST_NDARRAYS: + for channel_wise in (True, False): + TESTS.append((shape, p, channel_wise)) class TestRandKSpaceSpikeNoise(unittest.TestCase): @@ -37,44 +37,55 @@ def tearDown(self): set_determinism(None) @staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None] - return torch.Tensor(im) if as_tensor_input else im + return 
im_type(im) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) out = t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(im, out) - @parameterized.expand(TEST_CASES) - def test_1_prob(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_1_prob(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14] - t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) out = t(im) - base_t = KSpaceSpikeNoise(t.sampled_locs, [14], as_tensor_output) + base_t = KSpaceSpikeNoise(t.sampled_locs, [14]) out = out - base_t(im) + self.assertEqual(type(im), type(out)) + if isinstance(out, torch.Tensor): + self.assertEqual(out.device, im.device) + im, out = im.cpu(), out.cpu() np.testing.assert_allclose(out, im * 0) - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 15] - t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise, as_tensor_output) + t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise) t.set_random_state(42) out1 = t(deepcopy(im)) t.set_random_state(42) out2 = t(deepcopy(im)) + self.assertEqual(type(im), type(out1)) + if isinstance(out1, torch.Tensor): + self.assertEqual(out1.device, im.device) + out1, out2 = out1.cpu(), out2.cpu() np.testing.assert_allclose(out1, out2) - self.assertIsInstance(out1, torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, _, as_tensor_input, channel_wise): - im = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type, channel_wise): + im = self.get_data(im_shape, im_type) intensity_range = [14, 14.1] t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise) _ = t(deepcopy(im)) diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py index 3cb49f1c08..93b6ebfad8 100644 --- a/tests/test_rand_k_space_spike_noised.py +++ b/tests/test_rand_k_space_spike_noised.py @@ -19,12 +19,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d from monai.transforms import RandKSpaceSpikeNoised from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS -TEST_CASES = [] +TESTS = [] for shape in ((128, 64), (64, 48, 80)): - for as_tensor_output in (True, False): - for as_tensor_input in (True, False): - TEST_CASES.append((shape, as_tensor_output, as_tensor_input)) + for p in TEST_NDARRAYS: + TESTS.append((shape, p)) KEYS = ["image", "label"] @@ -38,17 +38,16 @@ def tearDown(self): set_determinism(None) 
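+    # TEST_NDARRAYS (from tests.utils) is assumed to provide one array
+    # constructor per supported backend (e.g. np.array plus torch.tensor
+    # variants), so each parameterized case runs once per array type.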
@staticmethod - def get_data(im_shape, as_tensor_input): + def get_data(im_shape, im_type): create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5) - ims = [im[None] for im in ims] - ims = [torch.Tensor(im) for im in ims] if as_tensor_input else ims + ims = [im_type(im[None]) for im in ims] return {k: v for k, v in zip(KEYS, ims)} - @parameterized.expand(TEST_CASES) - def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_same_result(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) intensity_range = (13, 15) t = RandKSpaceSpikeNoised( @@ -58,7 +57,6 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) t.set_rand_state(42) out1 = t(deepcopy(data)) @@ -67,12 +65,16 @@ def test_same_result(self, im_shape, as_tensor_output, as_tensor_input): out2 = t(deepcopy(data)) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() np.testing.assert_allclose(out1[k], out2[k], atol=1e-10) - self.assertIsInstance(out1[k], torch.Tensor if as_tensor_output else np.ndarray) - @parameterized.expand(TEST_CASES) - def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_0_prob(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) intensity_range = (13, 15) t1 = RandKSpaceSpikeNoised( KEYS, @@ -81,7 +83,6 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) t2 = RandKSpaceSpikeNoised( @@ -91,19 +92,25 @@ def test_0_prob(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=intensity_range, label_intensity_range=intensity_range, channel_wise=True, - as_tensor_output=as_tensor_output, ) out1 = t1(data) out2 = t2(data) for k in KEYS: + self.assertEqual(type(out1[k]), type(data[k])) + if isinstance(out1[k], torch.Tensor): + self.assertEqual(out1[k].device, data[k].device) + out1[k] = out1[k].cpu() + out2[k] = out2[k].cpu() + data[k] = data[k].cpu() + np.testing.assert_allclose(data[k], out1[k]) np.testing.assert_allclose(data[k], out2[k]) - @parameterized.expand(TEST_CASES) - def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): + @parameterized.expand(TESTS) + def test_intensity(self, im_shape, im_type): - data = self.get_data(im_shape, as_tensor_input) + data = self.get_data(im_shape, im_type) image_range = (15, 15.1) label_range = (14, 14.1) t = RandKSpaceSpikeNoised( @@ -113,7 +120,6 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): img_intensity_range=image_range, label_intensity_range=label_range, channel_wise=True, - as_tensor_output=True, ) _ = t(data) @@ -122,9 +128,9 @@ def test_intensity(self, im_shape, as_tensor_output, as_tensor_input): self.assertGreaterEqual(t.t_label.sampled_k_intensity[0], 14) self.assertLessEqual(t.t_label.sampled_k_intensity[0], 14.1) - @parameterized.expand(TEST_CASES) - def test_same_transformation(self, im_shape, _, 
as_tensor_input): - data = self.get_data(im_shape, as_tensor_input) + @parameterized.expand(TESTS) + def test_same_transformation(self, im_shape, im_type): + data = self.get_data(im_shape, im_type) # use same image for both dictionary entries to check same trans is applied to them data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])} @@ -138,11 +144,16 @@ def test_same_transformation(self, im_shape, _, as_tensor_input): label_intensity_range=label_range, channel_wise=True, common_sampling=True, - as_tensor_output=True, ) out = t(deepcopy(data)) + for k in KEYS: + self.assertEqual(type(out[k]), type(data[k])) + if isinstance(out[k], torch.Tensor): + self.assertEqual(out[k].device, data[k].device) + out[k] = out[k].cpu() + np.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]]) From 5cd43199ea31ae373397f528931a3c3a7be324e4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 22 Jul 2021 15:12:36 +0000 Subject: [PATCH 149/176] rand bias Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 14 ++++-- tests/test_random_bias_field.py | 68 ++++++++++++++++++++--------- tests/test_random_bias_fieldd.py | 25 +++++++++++ 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 81a25c7123..50a0d0c8b0 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -417,7 +417,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return scaler(img) -class RandBiasField(RandomizableTransform): +class RandBiasField(RandomizableTransform, TorchTransform, NumpyTransform): """ Random bias field augmentation for MR images. The bias field is considered as a linear combination of smoothly varying basis (polynomial) @@ -481,14 +481,14 @@ def _generate_random_field( raise NotImplementedError("only supports 2D or 3D fields") return field - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: DataObjects.Images) -> None: super().randomize(None) self.spatial_shape = data.shape[1:] self.rank = len(self.spatial_shape) n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, self.rank + 1)])) self._coeff = self.R.uniform(*self.coeff_range, n_coeff).tolist() - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Apply the transform to `img`. 
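+
+        A bias field is generated in numpy and converted to the input's array
+        type (and device, for torch tensors) before being applied, so the
+        output matches the input.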
""" @@ -496,6 +496,7 @@ def __call__(self, img: np.ndarray): if not self._do_transform: return img num_channels = img.shape[0] + _bias_fields: DataObjects.Images _bias_fields = np.stack( [ self._generate_random_field( @@ -505,7 +506,12 @@ def __call__(self, img: np.ndarray): ], axis=0, ) - return (img * _bias_fields).astype(self.dtype) + _bias_fields, *_ = convert_data_type( + _bias_fields, type(img), dtype=self.dtype, device=img.device if isinstance(img, torch.Tensor) else None + ) + out = img * _bias_fields + out, *_ = convert_data_type(out, dtype=self.dtype) + return out class NormalizeIntensity(TorchTransform, NumpyTransform): diff --git a/tests/test_random_bias_field.py b/tests/test_random_bias_field.py index 16b4ab6917..9b25f20db8 100644 --- a/tests/test_random_bias_field.py +++ b/tests/test_random_bias_field.py @@ -12,9 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandBiasField +from tests.utils import TEST_NDARRAYS TEST_CASES_2D = [{}, (3, 32, 32)] TEST_CASES_3D = [{}, (3, 32, 32, 32)] @@ -30,36 +32,60 @@ class TestRandBiasField(unittest.TestCase): ] ) def test_output_shape(self, class_args, img_shape): - for degree in [1, 2, 3]: - bias_field = RandBiasField(degree=degree, **class_args) - img = np.random.rand(*img_shape) - output = bias_field(img) - np.testing.assert_equal(output.shape, img_shape) - np.testing.assert_equal(output.dtype, bias_field.dtype) + for p in TEST_NDARRAYS: + for degree in [1, 2, 3]: + bias_field = RandBiasField(degree=degree, **class_args) + img = p(np.random.rand(*img_shape)) + output = bias_field(img) + + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() - img_zero = np.zeros([*img_shape]) - output_zero = bias_field(img_zero) - np.testing.assert_equal(output_zero, img_zero) + np.testing.assert_equal(output.shape, img_shape) + np.testing.assert_equal(output.dtype, bias_field.dtype) + + img_zero = np.zeros([*img_shape]) + output_zero = bias_field(img_zero) + np.testing.assert_equal(output_zero, img_zero) @parameterized.expand([TEST_CASES_2D_ZERO_RANGE]) def test_zero_range(self, class_args, img_shape): - bias_field = RandBiasField(**class_args) - img = np.random.rand(*img_shape) - output = bias_field(img) - np.testing.assert_equal(output, np.zeros(img_shape)) + for p in TEST_NDARRAYS: + bias_field = RandBiasField(**class_args) + img = p(np.random.rand(*img_shape)) + output = bias_field(img) + + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() + np.testing.assert_equal(output, np.zeros(img_shape)) @parameterized.expand([TEST_CASES_2D_ONES]) def test_one_range_input(self, class_args, expected): - bias_field = RandBiasField(**class_args) - img = np.ones([1, 2, 2]) - output = bias_field(img) - np.testing.assert_equal(output, expected.astype(bias_field.dtype)) + for p in TEST_NDARRAYS: + bias_field = RandBiasField(**class_args) + img = p(np.ones([1, 2, 2])) + output = bias_field(img) + self.assertEqual(type(img), type(output)) + if isinstance(output, torch.Tensor): + self.assertEqual(output.device, img.device) + output = output.cpu().numpy() + np.testing.assert_equal(output, expected.astype(bias_field.dtype)) def test_zero_prob(self): - bias_field = RandBiasField(prob=0.0) - img = np.random.rand(3, 32, 32) - output = bias_field(img) - np.testing.assert_equal(output, 
img)


if __name__ == "__main__":
diff --git a/tests/test_random_bias_fieldd.py b/tests/test_random_bias_fieldd.py
index 136eb41f2e..29182569dc 100644
--- a/tests/test_random_bias_fieldd.py
+++ b/tests/test_random_bias_fieldd.py
@@ -12,6 +12,7 @@
 import unittest

 import numpy as np
+import torch
 from parameterized import parameterized

 from monai.transforms import RandBiasFieldd
@@ -34,6 +35,12 @@ def test_output_shape(self, class_args, img_shape):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.random.rand(*img_shape)
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key].shape, img_shape)
         np.testing.assert_equal(output[key].dtype, bias_field.rand_bias_field.dtype)

@@ -43,6 +50,12 @@ def test_zero_range(self, class_args, img_shape):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.random.rand(*img_shape)
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key], np.zeros(img_shape))

 @parameterized.expand([TEST_CASES_2D_ONES])
 def test_one_range_input(self, class_args, expected):
         bias_field = RandBiasFieldd(keys=[key], **class_args)
         img = np.ones([1, 2, 2])
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key], expected.astype(bias_field.rand_bias_field.dtype))

     def test_zero_prob(self):
         bias_field = RandBiasFieldd(keys=[key], prob=0.0)
         img = np.random.rand(3, 32, 32)
         output = bias_field({key: img})
+
+        self.assertEqual(type(img), type(output[key]))
+        if isinstance(output[key], torch.Tensor):
+            self.assertEqual(output[key].device, img.device)
+            output[key] = output[key].cpu().numpy()
+
         np.testing.assert_equal(output[key], img)


From ae673c922c21e03b2c178a7aa74a7e1dc569f91a Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Thu, 22 Jul 2021 16:02:04 +0000
Subject: [PATCH 150/176] vote ensemble

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/transforms/post/array.py      | 41 ++++++++++-----
 monai/transforms/post/dictionary.py | 11 ++--
 tests/test_mean_ensemble.py         | 19 ++++++-
 tests/test_mean_ensembled.py        | 81 +++++++++++++++++------------
 4 files changed, 100 insertions(+), 52 deletions(-)

diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 0ca4d8122f..54e7c59bd8 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -23,7 +23,7 @@

 from monai.networks import one_hot
 from monai.networks.layers import GaussianFilter
-from monai.transforms.transform import TorchTransform, Transform
+from monai.transforms.transform import NumpyTransform, TorchTransform, 
Transform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import ensure_tuple from monai.utils.enums import DataObjects @@ -361,7 +361,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: return out -class MeanEnsemble(Transform): +class MeanEnsemble(TorchTransform, NumpyTransform): """ Execute mean ensemble on the input data. The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -384,24 +384,41 @@ class MeanEnsemble(Transform): """ - def __init__(self, weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None) -> None: - self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None + def __init__(self, weights: Optional[Union[Sequence[float], DataObjects.Images]] = None) -> None: + if weights is None: + self.weights = None + elif isinstance(weights, (torch.Tensor, np.ndarray)): + self.weights = weights + else: + self.weights = torch.as_tensor(weights, dtype=torch.float) + + def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) -> DataObjects.Images: + if isinstance(img, (torch.Tensor, np.ndarray)): + img_ = img + elif isinstance(img[0], torch.Tensor): + img_ = torch.stack(img) # type: ignore + else: + img_ = np.stack(img) - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: - img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) if self.weights is not None: - self.weights = self.weights.to(img_.device) + self.weights, *_ = convert_data_type( + self.weights, type(img_), device=img_.device if isinstance(img_, torch.Tensor) else None + ) shape = tuple(self.weights.shape) - for _ in range(img_.ndimension() - self.weights.ndimension()): + for _ in range(img_.ndim - self.weights.ndim): shape += (1,) weights = self.weights.reshape(*shape) - img_ = img_ * weights / weights.mean(dim=0, keepdim=True) + if isinstance(img_, torch.Tensor): + # torch can only do the mean on floats + img_ = img_ * weights / weights.float().mean(dim=0, keepdim=True) # type: ignore + else: + img_ = img_ * weights / weights.mean(axis=0, keepdims=True) # type: ignore - return torch.mean(img_, dim=0) + return img_.mean(0) # type: ignore -class VoteEnsemble(TorchTransform): +class VoteEnsemble(Transform): """ Execute vote ensemble on the input data. 
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -424,7 +441,7 @@ class VoteEnsemble(TorchTransform): def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes - def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Tensor: + def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) -> DataObjects.Images: img_ = ( torch.stack([torch.as_tensor(i) for i in img]) if isinstance(img, (tuple, list)) else torch.as_tensor(img) ) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index a3ad56730e..90d38396e5 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -17,9 +17,8 @@ import warnings from copy import deepcopy -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence, Union -import numpy as np import torch from monai.config import KeysCollection @@ -241,7 +240,7 @@ class Ensembled(MapTransform): def __init__( self, keys: KeysCollection, - ensemble: Callable[[Union[Sequence[torch.Tensor], torch.Tensor]], torch.Tensor], + ensemble: Callable[[Union[Sequence[DataObjects.Images], DataObjects.Images]], DataObjects.Images], output_key: Optional[str] = None, allow_missing_keys: bool = False, ) -> None: @@ -267,9 +266,9 @@ def __init__( raise ValueError("Incompatible values: len(self.keys) > 1 and output_key=None.") self.output_key = output_key if output_key is not None else self.keys[0] - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: d = dict(data) - items: Union[List[torch.Tensor], torch.Tensor] + items: Union[List[DataObjects.Images], DataObjects.Images] if len(self.keys) == 1: items = d[self.keys[0]] else: @@ -288,7 +287,7 @@ def __init__( self, keys: KeysCollection, output_key: Optional[str] = None, - weights: Optional[Union[Sequence[float], torch.Tensor, np.ndarray]] = None, + weights: Optional[Union[Sequence[float], DataObjects.Images]] = None, ) -> None: """ Args: diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 7e08846beb..6bc180bb1e 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.transforms import MeanEnsemble +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ {"weights": None}, @@ -57,8 +58,22 @@ class TestMeanEnsemble(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_value(self, input_param, img, expected_value): - result = MeanEnsemble(**input_param)(img) - torch.testing.assert_allclose(result, expected_value) + for i, p in enumerate(TEST_NDARRAYS): + if isinstance(img, list): + im = [p(i) for i in img] + im_type = type(im[0]) + im_device = im[0].device if isinstance(im[0], torch.Tensor) else None + else: + im = p(img) + im_type = type(im) + im_device = im.device if isinstance(im, torch.Tensor) else None + + result = MeanEnsemble(**input_param)(im) + self.assertEqual(im_type, type(result)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) diff --git a/tests/test_mean_ensembled.py 
b/tests/test_mean_ensembled.py index ea77ef18a0..4505d34944 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -16,49 +16,66 @@ from parameterized import parameterized from monai.transforms import MeanEnsembled +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) + 1, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2) + 2)}, + torch.ones(2, 2, 2) + 1, + ] + ) -TEST_CASE_2 = [ - {"keys": "output", "weights": None}, - {"output": torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])}, - torch.ones(2, 2, 2) + 1, -] + TESTS.append( + [ + {"keys": "output", "weights": None}, + {"output": p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]))}, + torch.ones(2, 2, 2) + 1, + ] + ) -TEST_CASE_3 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * 2.5, -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [1, 3]}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * 2.5, + ] + ) -TEST_CASE_4 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, - {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, - torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, + {"pred0": p(torch.ones(2, 2, 2)), "pred1": p(torch.ones(2, 2, 2) + 2)}, + torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), + ] + ) -TEST_CASE_5 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": np.array([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + ] + ) -TEST_CASE_6 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), -] + TESTS.append( + [ + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": torch.tensor([[[1, 3]], [[3, 1]]])}, + {"pred0": p(torch.ones(2, 2, 2, 2)), "pred1": p(torch.ones(2, 2, 2, 2) + 2)}, + torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + ] + ) class TestMeanEnsembled(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_param, data, expected_value): result = MeanEnsembled(**input_param)(data) - torch.testing.assert_allclose(result["output"], expected_value) + if isinstance(result["output"], torch.Tensor): + result["output"] = result["output"].cpu() + np.testing.assert_allclose(result["output"], expected_value) def 
test_cuda_value(self): img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]) From 9264f1b3aeb47ba709d7fbb84a51f458d9898174 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 23 Jul 2021 12:13:14 +0100 Subject: [PATCH 151/176] post merge fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/metrics/utils.py | 3 ++- monai/transforms/utils.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index b33d9a68a6..ffb52ffe49 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -16,7 +16,8 @@ from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils.enums import DataObjects +from monai.utils.enums import DataObjects, MetricReduction +from monai.utils.module import look_up_option, optional_import binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 5971ab0ea3..cb1da8fd24 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -388,11 +388,13 @@ def weighted_patch_samples( if (v < 0).any(): v -= np.min(v) # shifting to non-negative v = v.cumsum(0) + idx: DataObjects.Images if not v[-1] or not torch.as_tensor(v[-1]).isfinite() or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) if isinstance(v, torch.Tensor): idx = torch.as_tensor(idx, device=v.device) else: + r: DataObjects.Images r = r_state.random(n_samples) if isinstance(v, np.ndarray): idx = v.searchsorted(r * v[-1], side="right") @@ -400,6 +402,7 @@ def weighted_patch_samples( r = torch.as_tensor(r, device=v.device) idx = torch.searchsorted(v, r * v[-1], right=True) # compensate 'valid' mode + diff: DataObjects.Images diff = np.minimum(win_size, img_size) // 2 if isinstance(v, torch.Tensor): diff = torch.as_tensor(diff, device=v.device) From af4252c5b656ae5fab605e8a75ff422acddd2452 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 23 Jul 2021 13:38:40 +0100 Subject: [PATCH 152/176] one hot for np as well Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 25 ++++++++++++++++++++----- tests/test_to_onehot.py | 36 ++++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9d20d2a83b..a9820ff55c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -17,9 +17,14 @@ from contextlib import contextmanager from typing import Any, Callable, Mapping, Optional, Sequence, Union +import numpy as np import torch import torch.nn as nn +from monai.config.type_definitions import DtypeLike +from monai.utils.enums import DataObjects +from monai.utils.misc import dtype_convert + __all__ = [ "one_hot", "slice_channels", @@ -35,7 +40,9 @@ ] -def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: +def one_hot( + labels: DataObjects.Images, num_classes: int, dtype: Union[DtypeLike, torch.dtype] = torch.float, dim: int = 1 +) -> DataObjects.Images: """ For every value v in `labels`, the value in the output will be either 1 or 0. 
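
The numpy branch introduced in this patch builds the one-hot encoding by indexing an identity matrix: `np.eye(num_classes)[labels]` appends a one-hot axis at the end, which `np.moveaxis` then relocates to the requested channel dimension. A minimal standalone sketch of that trick, with illustrative shapes (plain numpy, not the MONAI code itself):

import numpy as np

labels = np.array([[0, 2], [1, 0]])   # integer class labels, spatial shape (2, 2)
num_classes = 3
onehot = np.eye(num_classes)[labels]  # indexing appends a one-hot axis: (2, 2, 3)
onehot = np.moveaxis(onehot, -1, 0)   # move the one-hot axis to dim 0: (3, 2, 2)
print(onehot.shape)                   # (3, 2, 2)
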
Each vector along the `dim`-th dimension has the "one-hot" format, i.e., it has a total length of `num_classes`, @@ -69,21 +76,29 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f print(out.shape) # torch.Size([2, 2, 2, 2, 2]) """ + dtype = dtype_convert(dtype, type(labels)) # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) - labels = torch.reshape(labels, shape) + labels = labels.reshape(shape) sh = list(labels.shape) if sh[dim] != 1: raise AssertionError("labels should have a channel with length equal to one.") - sh[dim] = num_classes + if isinstance(labels, np.ndarray): + labels = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end + labels = labels.astype(dtype) + labels = labels.squeeze(dim) # remove singleton + labels = np.moveaxis(labels, -1, dim) # move one hot dim to desired index + + else: + sh[dim] = num_classes - o = torch.zeros(size=sh, dtype=dtype, device=labels.device) - labels = o.scatter_(dim=dim, index=labels.long(), value=1) + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) return labels diff --git a/tests/test_to_onehot.py b/tests/test_to_onehot.py index c3e373955d..c6c3264f74 100644 --- a/tests/test_to_onehot.py +++ b/tests/test_to_onehot.py @@ -16,6 +16,8 @@ from parameterized import parameterized from monai.networks import one_hot +from monai.utils.misc import dtype_convert +from tests.utils import TEST_NDARRAYS TEST_CASE_1 = [ # single channel 2D, batch 3, shape (2, 1, 2, 2) {"labels": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), "num_classes": 3}, @@ -44,16 +46,30 @@ class TestToOneHot(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_data, expected_shape, expected_result=None): - result = one_hot(**input_data) - self.assertEqual(result.shape, expected_shape) - if expected_result is not None: - self.assertTrue(np.allclose(expected_result, result.numpy())) - - if "dtype" in input_data: - self.assertEqual(result.dtype, input_data["dtype"]) - else: - # by default, expecting float type - self.assertEqual(result.dtype, torch.float) + results = [] + for p in TEST_NDARRAYS: + input_data_mod = {k: p(v) if isinstance(v, torch.Tensor) else v for k, v in input_data.items()} + orig_dtype = input_data_mod["labels"].dtype + result = one_hot(**input_data_mod) + self.assertEqual(result.shape, expected_shape) + self.assertEqual(type(result), type(input_data_mod["labels"])) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, input_data_mod["labels"].device) + result = result.cpu().numpy() + + if expected_result is not None: + self.assertTrue(np.allclose(expected_result, result)) + + self.assertEqual(input_data_mod["labels"].dtype, orig_dtype) + if "dtype" in input_data: + self.assertEqual(result.dtype, dtype_convert(input_data_mod["dtype"], type(result))) + else: + # by default, expecting float type + self.assertEqual(result.dtype, dtype_convert(torch.float, type(result))) + + results.append(result) + if len(results) > 1: + np.testing.assert_allclose(results[0], results[-1]) if __name__ == "__main__": From f3e9b6afcf835b01577093bd05b1d321760bb493 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 23 Jul 2021 18:14:37 +0100 Subject: [PATCH 153/176] loadsa stuff Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- monai/losses/dice.py | 4 +- monai/losses/focal_loss.py | 2 +- monai/losses/tversky.py | 2 +- monai/networks/utils.py | 17 +-- monai/transforms/intensity/array.py | 16 ++- monai/transforms/intensity/dictionary.py | 10 +- monai/transforms/post/array.py | 129 +++++++++++------- monai/transforms/utils.py | 19 ++- .../test_keep_largest_connected_component.py | 31 +++-- tests/test_mean_ensemble.py | 2 +- tests/test_rand_histogram_shift.py | 51 ++++--- tests/test_rand_histogram_shiftd.py | 76 +++++++---- tests/test_vote_ensemble.py | 16 ++- tests/test_vote_ensembled.py | 9 +- 14 files changed, 242 insertions(+), 142 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index aa58490136..e8e6039bc7 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -130,7 +130,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot(target, num_classes=n_pred_ch) # type: ignore if not self.include_background: if n_pred_ch == 1: @@ -306,7 +306,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot(target, num_classes=n_pred_ch) # type: ignore if not self.include_background: if n_pred_ch == 1: diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index b4b3698e5b..f4c361168e 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -96,7 +96,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot(target, num_classes=n_pred_ch) # type: ignore if not self.include_background: if n_pred_ch == 1: diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 1d75b9e8cc..8170920b47 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -120,7 +120,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) + target = one_hot(target, num_classes=n_pred_ch) # type: ignore if not self.include_background: if n_pred_ch == 1: diff --git a/monai/networks/utils.py b/monai/networks/utils.py index a9820ff55c..dac2f6039d 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -89,16 +89,17 @@ def one_hot( raise AssertionError("labels should have a channel with length equal to one.") if isinstance(labels, np.ndarray): - labels = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end - labels = labels.astype(dtype) - labels = labels.squeeze(dim) # remove singleton - labels = np.moveaxis(labels, -1, dim) # move one hot dim to desired index + label_np: np.ndarray + label_np = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end + label_np = label_np.astype(dtype) # type: ignore + label_np = label_np.squeeze(dim) # remove singleton + label_np = np.moveaxis(label_np, -1, dim) # move one hot dim to desired index + return label_np - else: - sh[dim] = num_classes + sh[dim] = num_classes - o = torch.zeros(size=sh, dtype=dtype, 
device=labels.device) - labels = o.scatter_(dim=dim, index=labels.long(), value=1) + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) # type: ignore + labels = o.scatter_(dim=dim, index=labels.long(), value=1) return labels diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a81d4dff78..dfc8baa137 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1164,7 +1164,7 @@ def __call__(self, img: np.ndarray): return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) -class RandHistogramShift(RandomizableTransform): +class RandHistogramShift(RandomizableTransform, NumpyTransform): """ Apply random nonlinear transform to the image's intensity histogram. @@ -1199,16 +1199,20 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize() if not self._do_transform: return img - img_min, img_max = img.min(), img.max() + + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - return np.asarray( - np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype - ) + + out_np = np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled) + out, *_ = convert_data_type(out_np, orig_type, orig_device, dtype=img.dtype) + return out class RandGibbsNoise(TorchTransform, NumpyTransform, RandomizableTransform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 5106ab0350..74b491dc99 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1027,11 +1027,15 @@ def __call__(self, data: DataObjects.Mapping) -> DataObjects.Dict: if not self._do_transform: return d for key in self.key_iterator(d): - img_min, img_max = d[key].min(), d[key].max() + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(d[key], np.ndarray) # type: ignore + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - dtype = d[key].dtype - d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype) + + img_np = np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled) + out, *_ = convert_data_type(img_np, orig_type, orig_device, dtype=d[key].dtype) + d[key] = out return d diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 54e7c59bd8..47c0fae741 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -23,7 +23,7 @@ from monai.networks import one_hot from monai.networks.layers import GaussianFilter -from monai.transforms.transform import NumpyTransform, TorchTransform, Transform +from monai.transforms.transform import NumpyTransform, TorchTransform from monai.transforms.utils import get_largest_connected_component_mask from monai.utils import 
ensure_tuple from monai.utils.enums import DataObjects @@ -40,7 +40,24 @@ ] -class Activations(TorchTransform): +def _sigmoid(z): + if isinstance(z, torch.Tensor): + return torch.sigmoid(z) + return 1 / (1 + np.exp(-z)) + + +def _softmax(z, dim): + if isinstance(z, torch.Tensor): + return torch.softmax(z, dim=dim) + + max = np.max(z, axis=dim, keepdims=True) # returns max of each row and keeps same dims + e_x = np.exp(z - max) # subtracts each row with its max value + sum = np.sum(e_x, axis=dim, keepdims=True) # returns sum of each row and keeps same dims + f_x = e_x / sum + return f_x + + +class Activations(TorchTransform, NumpyTransform): """ Add activation operations to the model output, typically `Sigmoid` or `Softmax`. @@ -91,24 +108,27 @@ def __call__( if other is not None and not callable(other): raise TypeError(f"other must be None or callable but is {type(other).__name__}.") - img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=torch.float32) # type: ignore - # convert to float as activation must operate on float tensor if sigmoid or self.sigmoid: - img_t = torch.sigmoid(img_t) + img = _sigmoid(img) if softmax or self.softmax: - img_t = torch.softmax(img_t, dim=0) + img = _softmax(img, dim=0) act_func = self.other if other is None else other if act_func is not None: - img_t = act_func(img_t) + try: + img = act_func(img) + except TypeError as te: + # callable only works on torch.Tensors + if "must be Tensor, not numpy.ndarray" in str(te): + img, *_ = convert_data_type(img, torch.Tensor) + img = act_func(img) + img, *_ = convert_data_type(img, np.ndarray) - out, *_ = convert_data_type(img_t, orig_type, orig_device) - return out + return img -class AsDiscrete(TorchTransform): +class AsDiscrete(TorchTransform, NumpyTransform): """ Execute after model forward to transform model output to discrete values. It can complete below operations: @@ -170,22 +190,22 @@ def __call__( Defaults to ``self.logit_thresh``. """ - img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore - if argmax or self.argmax: - img_t = torch.argmax(img_t, dim=0, keepdim=True) + if isinstance(img, torch.Tensor): + img = torch.argmax(img, dim=0, keepdim=True) + else: + img = np.argmax(img, axis=0)[None] if to_onehot or self.to_onehot: _nclasses = self.n_classes if n_classes is None else n_classes if not isinstance(_nclasses, int): raise AssertionError("One of self.n_classes or n_classes must be an integer") - img_t = one_hot(img_t, num_classes=_nclasses, dim=0) + img = one_hot(img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: - img_t = img_t >= (logit_thresh or self.logit_thresh) + img = img >= (logit_thresh or self.logit_thresh) - out, *_ = convert_data_type(img_t, orig_type, orig_device, dtype=float) + out, *_ = convert_data_type(img, dtype=torch.float32) return out @@ -255,6 +275,11 @@ def __init__( self.independent = independent self.connectivity = connectivity + @staticmethod + def _astype(x, dtype=torch.uint8): + x, *_ = convert_data_type(x, dtype=dtype) + return x + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ Args: @@ -263,44 +288,43 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: Returns: A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]). 
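
As a side note on the `_softmax` helper added for `Activations` above: it is the standard max-shifted (numerically stable) softmax, so the numpy path should agree with `torch.softmax` up to floating-point error. A standalone sanity check, with illustrative shapes:

import numpy as np
import torch

z = np.random.randn(4, 5).astype(np.float32)
e = np.exp(z - z.max(axis=0, keepdims=True))  # subtracting the max avoids overflow
np_soft = e / e.sum(axis=0, keepdims=True)
torch_soft = torch.softmax(torch.as_tensor(z), dim=0).numpy()
assert np.allclose(np_soft, torch_soft, atol=1e-6)
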
""" - img_t: torch.Tensor - img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor) # type: ignore - - if img_t.shape[0] == 1: - img_t = torch.squeeze(img_t, dim=0) + if img.shape[0] == 1: + img = img.squeeze(0) if self.independent: for i in self.applied_labels: - foreground = (img_t == i).type(torch.uint8) + foreground = self._astype(img == i) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img_t[foreground != mask] = 0 + img[foreground != mask] = 0 else: - foreground = torch.zeros_like(img_t) + foreground = torch.zeros_like(img) if isinstance(img, torch.Tensor) else np.zeros_like(img) for i in self.applied_labels: - foreground += (img_t == i).type(torch.uint8) + foreground += self._astype(img == i) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img_t[foreground != mask] = 0 + img[foreground != mask] = 0 - output = torch.unsqueeze(img_t, dim=0) + output = img[None] else: # one-hot data is assumed to have binary value in each channel if self.independent: for i in self.applied_labels: - foreground = img_t[i, ...].type(torch.uint8) + foreground = self._astype(img[i, ...]) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img_t[i, ...][foreground != mask] = 0 + img[i, ...][foreground != mask] = 0 else: - applied_img = img_t[self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=0) + applied_img = self._astype(img[self.applied_labels, ...]) + foreground = applied_img.any(0) mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=0) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) + background_mask = (foreground != mask)[None] + if isinstance(background_mask, torch.Tensor): + background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) + else: + background_mask = np.repeat(background_mask, len(self.applied_labels), axis=0) applied_img[background_mask] = 0 - img_t[self.applied_labels, ...] = applied_img.type(img_t.type()) # type: ignore - output = img_t + img[self.applied_labels, ...] = self._astype(applied_img, img.dtype) + output = img - out, *_ = convert_data_type(output, orig_type, orig_device) - return out + return output class LabelToContour(TorchTransform): @@ -418,7 +442,7 @@ def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) return img_.mean(0) # type: ignore -class VoteEnsemble(Transform): +class VoteEnsemble(TorchTransform, NumpyTransform): """ Execute vote ensemble on the input data. 
The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]], @@ -442,26 +466,35 @@ def __init__(self, num_classes: Optional[int] = None) -> None: self.num_classes = num_classes def __call__(self, img: Union[Sequence[DataObjects.Images], DataObjects.Images]) -> DataObjects.Images: - img_ = ( - torch.stack([torch.as_tensor(i) for i in img]) if isinstance(img, (tuple, list)) else torch.as_tensor(img) - ) + if isinstance(img, (torch.Tensor, np.ndarray)): + img_ = img + elif isinstance(img[0], torch.Tensor): + img_ = torch.stack(img) # type: ignore + else: + img_ = np.stack(img) + if self.num_classes is not None: has_ch_dim = True - if img_.ndimension() > 1 and img_.shape[1] > 1: + if img_.ndim > 1 and img_.shape[1] > 1: warnings.warn("no need to specify num_classes for One-Hot format data.") else: - if img_.ndimension() == 1: + if img_.ndim == 1: # if no channel dim, need to remove channel dim after voting has_ch_dim = False img_ = one_hot(img_, self.num_classes, dim=1) - img_ = torch.mean(img_.float(), dim=0) + img_ = torch.mean(img_.float(), dim=0) if isinstance(img_, torch.Tensor) else np.mean(img_, axis=0) if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class - return torch.argmax(img_, dim=0, keepdim=has_ch_dim) + if isinstance(img_, torch.Tensor): + return torch.argmax(img_, dim=0, keepdim=has_ch_dim) + else: + img_ = np.argmax(img_, axis=0) + img_ = np.array(img_) if np.isscalar(img_) else img_ # numpy returns scalar if input was 1d + return img_[None] if has_ch_dim else img_ # for One-Hot data, round the float number to 0 or 1 - return torch.round(img_) + return torch.round(img_) if isinstance(img_, torch.Tensor) else np.round(img_) class ProbNMS(TorchTransform): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index cb1da8fd24..c646d7eb97 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -749,7 +749,9 @@ def generate_spatial_bounding_box( return box_start, box_end -def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: +def get_largest_connected_component_mask( + img: DataObjects.Images, connectivity: Optional[int] = None +) -> DataObjects.Images: """ Gets the largest connected component mask of an image. @@ -759,13 +761,16 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. """ - img_arr = img.detach().cpu().numpy() - largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) - img_arr = measure.label(img_arr, connectivity=connectivity) - if img_arr.max() != 0: - largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) + img_np: np.ndarray + img_np, orig_type, orig_device = convert_data_type(img, np.ndarray) # type: ignore + largest_cc = np.zeros_like(img_np) + img_np = measure.label(img_np, connectivity=connectivity) + + if img_np.max() != 0: + largest_cc[...] 
= img_np == (np.argmax(np.bincount(img_np.flat)[1:]) + 1)
 
-    return torch.as_tensor(largest_cc, device=img.device)
+    out, *_ = convert_data_type(largest_cc, orig_type, orig_device)
+    return out
 
 
 def get_extreme_points(
diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py
index a8835329ba..6eaa61897d 100644
--- a/tests/test_keep_largest_connected_component.py
+++ b/tests/test_keep_largest_connected_component.py
@@ -11,10 +11,12 @@
 
 import unittest
 
+import numpy as np
 import torch
 from parameterized import parameterized
 
 from monai.transforms import KeepLargestConnectedComponent
+from tests.utils import TEST_NDARRAYS
 
 grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])
 grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])
@@ -322,23 +324,24 @@ class TestKeepLargestConnectedComponent(unittest.TestCase):
     @parameterized.expand(VALID_CASES)
-    def test_correct_results(self, _, args, tensor, expected):
-        converter = KeepLargestConnectedComponent(**args)
-        if torch.cuda.is_available():
-            result = converter(tensor.clone().cuda())
-            assert torch.allclose(result, expected.cuda())
-        else:
-            result = converter(tensor.clone())
-            assert torch.allclose(result, expected)
+    def test_correct_results(self, _, args, img, expected):
+        for p in TEST_NDARRAYS:
+            converter = KeepLargestConnectedComponent(**args)
+            im = p(img.clone())
+            result = converter(im)
+            self.assertEqual(type(im), type(result))
+            if isinstance(result, torch.Tensor):
+                self.assertEqual(result.device, im.device)
+                result = result.cpu()
+            np.testing.assert_allclose(result, expected)
 
     @parameterized.expand(INVALID_CASES)
-    def test_raise_exception(self, _, args, tensor, expected_error):
-        with self.assertRaises(expected_error):
-            converter = KeepLargestConnectedComponent(**args)
-            if torch.cuda.is_available():
-                _ = converter(tensor.clone().cuda())
-            else:
-                _ = converter(tensor.clone())
+    def test_raise_exception(self, _, args, img, expected_error):
+        for p in TEST_NDARRAYS:
+            # keep assertRaises inside the loop so every array type is actually exercised
+            with self.assertRaises(expected_error):
+                converter = KeepLargestConnectedComponent(**args)
+                im = p(img)
+                _ = converter(im)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py
index 6bc180bb1e..a4595fe1ac 100644
--- a/tests/test_mean_ensemble.py
+++ b/tests/test_mean_ensemble.py
@@ -58,7 +58,7 @@ class TestMeanEnsemble(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
     def test_value(self, input_param, img, expected_value):
-        for i, p in enumerate(TEST_NDARRAYS):
+        for p in TEST_NDARRAYS:
             if isinstance(img, list):
                 im = [p(i) for i in img]
                 im_type = type(im[0])
diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py
index b258cc5a7e..59c8e51128 100644
--- a/tests/test_rand_histogram_shift.py
+++ b/tests/test_rand_histogram_shift.py
@@ -12,35 +12,48 @@
 import unittest
 
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.transforms import RandHistogramShift
-
-TEST_CASES = [
-    [
-        {"num_control_points": 5, "prob": 0.0},
-        {"img": np.arange(8).reshape((1, 2, 2, 2))},
-        np.arange(8).reshape((1, 2, 2, 2)),
-    ],
-    [
-        {"num_control_points": 5, "prob": 0.9},
-        {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)},
-        np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]),
-    ],
-    [
-        {"num_control_points": (5, 20), "prob": 0.9},
- {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), - ], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + np.arange(8).reshape((1, 2, 2, 2)), + ] + ) + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), + ] + ) + TESTS.append( + [ + {"num_control_points": (5, 20), "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), + ] + ) class TestRandHistogramShift(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) g.set_random_state(123) result = g(**input_data) + im_in = input_data["img"] + self.assertEqual(type(result), type(im_in)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_in.device) + result = result.cpu() np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 806e4f5cf2..a1b11a0905 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -12,47 +12,67 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandHistogramShiftD +from tests.utils import TEST_NDARRAYS -TEST_CASES = [ - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - ], - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], - [ - {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "seg": p(np.ones(8).reshape((1, 2, 2, 2)))}, + {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) + TESTS.append( + [ + {"keys": 
("img",), "num_control_points": (5, 20), "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) class TestRandHistogramShiftD(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): g = RandHistogramShiftD(**input_param) g.set_random_state(123) res = g(input_data) for key in res: result = res[key] + im_in = input_data[key] + self.assertEqual(type(result), type(im_in)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_in.device) + result = result.cpu() + expected = expected_val[key] if isinstance(expected_val, dict) else expected_val np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 7abe120441..ba48ed1825 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -11,6 +11,7 @@ import unittest +import numpy as np import torch from parameterized import parameterized @@ -80,10 +81,21 @@ class TestVoteEnsemble(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, in_type, input_param, img, expected_value): - result = VoteEnsemble(**input_param)([in_type(i) for i in img]) + if isinstance(img, list): + im = [in_type(i) for i in img] + im_type = type(im[0]) + im_device = im[0].device if isinstance(im[0], torch.Tensor) else None + else: + im = in_type(img) + im_type = type(im) + im_device = im.device if isinstance(im, torch.Tensor) else None + + result = VoteEnsemble(**input_param)(im) + self.assertEqual(im_type, type(result)) if isinstance(result, torch.Tensor): + self.assertEqual(result.device, im_device) result = result.cpu() - torch.testing.assert_allclose(result, expected_value) + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack( diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index ab8daa1e5b..972469d1db 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ -11,6 +11,7 @@ import unittest +import numpy as np import torch from parameterized import parameterized @@ -92,8 +93,12 @@ def test_value(self, in_type, input_param, img, expected_value): for k, v in img.items(): img[k] = in_type(v) result = VoteEnsembled(**input_param)(img)["output"] - result = result.cpu() if isinstance(result, torch.Tensor) else result - torch.testing.assert_allclose(result, expected_value) + in_im = img["pred0"] if "pred0" in img else img["output"] + self.assertEqual(type(result), type(in_im)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, in_im.device) + result = result.cpu() + np.testing.assert_allclose(result, expected_value) def test_cuda_value(self): img = torch.stack( From 57fc771a88d62e88c6ad6bb8773334c59af897ee Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 10:06:08 +0100 Subject: [PATCH 154/176] fg_bg_to_indices Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 1 - monai/transforms/intensity/dictionary.py | 2 +- monai/transforms/utility/array.py | 12 +-- monai/transforms/utils.py | 3 + tests/test_fg_bg_to_indices.py | 98 
+++++++++++++---------- tests/test_fg_bg_to_indicesd.py | 99 ++++++++++++++++-------- 6 files changed, 134 insertions(+), 81 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ec36ecc7d1..b7734890c2 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -30,7 +30,6 @@ from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, - dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 34f73cb1ea..7b90505705 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -42,7 +42,7 @@ ThresholdIntensity, ) from monai.transforms.transform import MapTransform, RandomizableTransform -from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple +from monai.utils import ensure_tuple_rep, ensure_tuple_size, fall_back_tuple from monai.utils.enums import DataObjects from monai.utils.misc import convert_data_type diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index b45fbec0b1..82d8ac3761 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -25,6 +25,7 @@ from monai.config import DtypeLike from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform from monai.transforms.utils import ( + _unravel_index, convert_to_numpy, convert_to_tensor, extreme_points_to_image, @@ -702,7 +703,7 @@ def __call__( return data -class FgBgToIndices(Transform): +class FgBgToIndices(TorchTransform, NumpyTransform): def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None: """ Compute foreground and background of the input label data, return the indices. 
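
`FgBgToIndices` delegates to `map_binary_to_indices`, which, ignoring the optional image threshold, boils down to flattening the channel-merged mask and calling `nonzero` on it and on its complement. A standalone numpy sketch that reproduces the expected indices of the first test case below:

import numpy as np

label = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])
label_flat = label.any(0).ravel()  # merge channels, then flatten spatially
fg = np.nonzero(label_flat)[0]     # foreground voxel indices
bg = np.nonzero(~label_flat)[0]    # background voxel indices
print(fg)                          # [1 2 3 5 6 7]
print(bg)                          # [0 4 8]
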
@@ -722,8 +723,8 @@ def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence def __call__( self, - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: DataObjects.Images, + image: Optional[DataObjects.Images] = None, output_shape: Optional[Sequence[int]] = None, ) -> Tuple[DataObjects.Images, DataObjects.Images]: """ @@ -738,8 +739,9 @@ def __call__( output_shape = self.output_shape fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) if output_shape is not None: - fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices]) - bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) + stack = np.stack if isinstance(label, np.ndarray) else torch.stack + fg_indices = stack([_unravel_index(i, output_shape) for i in fg_indices]) # type: ignore + bg_indices = stack([_unravel_index(i, output_shape) for i in bg_indices]) # type: ignore return fg_indices, bg_indices diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index c646d7eb97..5365c3f8a7 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -282,6 +282,9 @@ def _nonzero(x): fg_indices = _nonzero(label_flat) if image is not None: img_flat = (image > image_threshold).any(0).ravel() + img_flat, *_ = convert_data_type( + img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None + ) bg_indices = _nonzero(img_flat & ~label_flat) else: bg_indices = _nonzero(~label_flat) diff --git a/tests/test_fg_bg_to_indices.py b/tests/test_fg_bg_to_indices.py index 98626c7028..e8451ba334 100644 --- a/tests/test_fg_bg_to_indices.py +++ b/tests/test_fg_bg_to_indices.py @@ -12,57 +12,73 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import FgBgToIndices +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - None, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"image_threshold": 0.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_3 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_4 = [ - {"image_threshold": 1.0, "output_shape": None}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_5 = [ - {"image_threshold": 1.0, "output_shape": [3, 3]}, - np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + None, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 0.0, "output_shape": None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + q(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": 
None}, + p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": None}, + p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + {"image_threshold": 1.0, "output_shape": [3, 3]}, + p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + np.array([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, label, image, expected_fg, expected_bg): fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + for indices, expected in zip((fg_indices, bg_indices), (expected_fg, expected_bg)): + self.assertEqual(type(indices), type(label)) + if isinstance(indices, torch.Tensor): + self.assertEqual(indices.device, label.device) + indices = indices.cpu() + np.testing.assert_allclose(indices, expected) if __name__ == "__main__": diff --git a/tests/test_fg_bg_to_indicesd.py b/tests/test_fg_bg_to_indicesd.py index ce6ca30f1b..0a47bda6a0 100644 --- a/tests/test_fg_bg_to_indicesd.py +++ b/tests/test_fg_bg_to_indicesd.py @@ -12,52 +12,85 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import FgBgToIndicesd +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": None, "image_threshold": 0.0, "output_shape": None}, + {"label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]))}, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) -TEST_CASE_3 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 0.0, "output_shape": None}, + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": q(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) -TEST_CASE_4 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, 
"output_shape": None}, + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) -TEST_CASE_5 = [ - {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, - {"label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])}, - np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), - np.array([[0, 0], [2, 2]]), -] + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": None}, + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + + TESTS.append( + [ + {"keys": "label", "image_key": "image", "image_threshold": 1.0, "output_shape": [3, 3]}, + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": q(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + }, + np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]), + np.array([[0, 0], [2, 2]]), + ] + ) class TestFgBgToIndicesd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, data, expected_fg, expected_bg): result = FgBgToIndicesd(**input_data)(data) - np.testing.assert_allclose(result["label_fg_indices"], expected_fg) - np.testing.assert_allclose(result["label_bg_indices"], expected_bg) + for key, expected in zip(("fg", "bg"), (expected_fg, expected_bg)): + r = result[f"label_{key}_indices"] + self.assertEqual(type(r), type(data["label"])) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, data["label"].device) + r = r.cpu() + np.testing.assert_allclose(r, expected) if __name__ == "__main__": From b940325f0c15f8e3ca085c1401f26e0f4026c109 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 10:18:29 +0100 Subject: [PATCH 155/176] RandCoarseDropout Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 4 +- tests/test_rand_coarse_dropout.py | 59 ++++++++++++------- tests/test_rand_coarse_dropoutd.py | 88 +++++++++++++++++------------ 3 files changed, 93 insertions(+), 58 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index b7734890c2..ae3c0f81f0 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1682,7 +1682,7 @@ def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, t return img, torch.device("cpu") -class RandCoarseDropout(RandomizableTransform): +class RandCoarseDropout(RandomizableTransform, TorchTransform, NumpyTransform): """ Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. 
Refer to: https://arxiv.org/abs/1708.04552 and: @@ -1738,7 +1738,7 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, size) self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) - def __call__(self, img: np.ndarray): + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: self.randomize(img.shape[1:]) if self._do_transform: for h in self.hole_coords: diff --git a/tests/test_rand_coarse_dropout.py b/tests/test_rand_coarse_dropout.py index 235a391567..a7c38a6c66 100644 --- a/tests/test_rand_coarse_dropout.py +++ b/tests/test_rand_coarse_dropout.py @@ -12,41 +12,54 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropout from monai.utils import fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_1 = [ - {"holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_2 = [ - {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "max_spatial_size": [4, 4, 3], "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) -TEST_CASE_3 = [ - {"holes": 2, "spatial_size": [2, -1, 2], "fill_value": 5, "max_spatial_size": [4, 4, -1], "prob": 1.0}, - np.random.randint(0, 2, size=[3, 3, 3, 4]), - (3, 3, 3, 4), -] + TESTS.append( + [ + {"holes": 2, "spatial_size": [2, -1, 2], "fill_value": 5, "max_spatial_size": [4, 4, -1], "prob": 1.0}, + p(np.random.randint(0, 2, size=[3, 3, 3, 4])), + (3, 3, 3, 4), + ] + ) class TestRandCoarseDropout(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_shape): dropout = RandCoarseDropout(**input_param) result = dropout(input_data) + np.testing.assert_equal(result.shape, expected_shape) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data.shape[1:]) @@ -60,6 +73,10 @@ def test_value(self, input_param, input_data, expected_shape): for h in dropout.hole_coords: data = result[h] + self.assertEqual(type(data), type(input_data)) + if isinstance(data, torch.Tensor): + self.assertEqual(data.device, input_data.device) + data = data.cpu() np.testing.assert_allclose(data, input_param.get("fill_value", 0)) if max_spatial_size is None: self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) diff --git a/tests/test_rand_coarse_dropoutd.py b/tests/test_rand_coarse_dropoutd.py index d189a80f56..1511590ed6 100644 --- a/tests/test_rand_coarse_dropoutd.py +++ b/tests/test_rand_coarse_dropoutd.py @@ -12,55 +12,69 @@ import unittest 
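
The `-1` entries in `spatial_size`/`max_spatial_size` above are resolved by `fall_back_tuple`, which substitutes the corresponding image dimension for any non-positive entry; the tests therefore call it before checking hole sizes. A quick illustration (assuming the `monai.utils.fall_back_tuple` behaviour relied on by these tests):

from monai.utils import fall_back_tuple

# non-positive entries fall back to the default (here, the image's spatial size)
print(fall_back_tuple((2, -1, 2), (3, 3, 4)))  # (2, 3, 2)
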
import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandCoarseDropoutd from monai.utils import fall_back_tuple +from tests.utils import TEST_NDARRAYS -TEST_CASE_0 = [ - {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] +TESTS = [] +for p in TEST_NDARRAYS: -TEST_CASE_1 = [ - {"keys": "img", "holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + {"keys": "img", "holes": 2, "spatial_size": [2, 2, 2], "fill_value": 5, "prob": 1.0}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) -TEST_CASE_2 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, 2, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, 3], - "prob": 1.0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + {"keys": "img", "holes": 1, "spatial_size": [1, 2, 3], "fill_value": 5, "max_holes": 5, "prob": 1.0}, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) -TEST_CASE_3 = [ - { - "keys": "img", - "holes": 2, - "spatial_size": [2, -1, 2], - "fill_value": 5, - "max_spatial_size": [4, 4, -1], - "prob": 1.0, - }, - {"img": np.random.randint(0, 2, size=[3, 3, 3, 4])}, - (3, 3, 3, 4), -] + TESTS.append( + [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, 2, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, 3], + "prob": 1.0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) + + TESTS.append( + [ + { + "keys": "img", + "holes": 2, + "spatial_size": [2, -1, 2], + "fill_value": 5, + "max_spatial_size": [4, 4, -1], + "prob": 1.0, + }, + {"img": p(np.random.randint(0, 2, size=[3, 3, 3, 4]))}, + (3, 3, 3, 4), + ] + ) class TestRandCoarseDropoutd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_shape): dropout = RandCoarseDropoutd(**input_param) result = dropout(input_data)["img"] + np.testing.assert_equal(result.shape, expected_shape) holes = input_param.get("holes") max_holes = input_param.get("max_holes") spatial_size = fall_back_tuple(input_param.get("spatial_size"), input_data["img"].shape[1:]) @@ -74,6 +88,10 @@ def test_value(self, input_param, input_data, expected_shape): for h in dropout.hole_coords: data = result[h] + self.assertEqual(type(data), type(input_data["img"])) + if isinstance(data, torch.Tensor): + self.assertEqual(data.device, input_data["img"].device) + data = data.cpu() np.testing.assert_allclose(data, input_param.get("fill_value", 0)) if max_spatial_size is None: self.assertTupleEqual(data.shape[1:], tuple(spatial_size)) From 7f6f05fbfcc7a1d5c88acf07e80b480dcecefd4a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 12:50:54 +0100 Subject: [PATCH 156/176] ConvertToMultiChannelBasedOnBratsClasses Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 14 ++++---- tests/test_convert_to_multi_channel.py | 44 +++++++++++++++++-------- tests/test_convert_to_multi_channeld.py | 35 +++++++++++++++----- 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/monai/transforms/utility/array.py 
b/monai/transforms/utility/array.py index 82d8ac3761..9957ef62c3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -797,7 +797,7 @@ def __call__( return indices -class ConvertToMultiChannelBasedOnBratsClasses(Transform): +class ConvertToMultiChannelBasedOnBratsClasses(TorchTransform, NumpyTransform): """ Convert labels to multi channels based on brats18 classes: label 1 is the necrotic and non-enhancing tumor core @@ -807,19 +807,21 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): and ET (Enhancing tumor). """ - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: DataObjects.Images) -> DataObjects.Images: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: - img = np.squeeze(img, axis=0) + img = img.squeeze(0) result = [] # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(img == 1, img == 4)) + result.append((img == 1) | (img == 4)) # | is np.logical_or # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) + result.append((img == 1) | (img == 4) | (img == 2)) # label 4 is ET result.append(img == 4) - return np.stack(result, axis=0) + if isinstance(img, np.ndarray): + return np.stack(result, axis=0) + return torch.stack(result, dim=0) class AddExtremePointsChannel(Randomizable, TorchTransform): diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 2f7a38e6e4..a350f8fdd8 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -12,31 +12,49 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p(np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])), + np.array( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ] + ) -TEST_CASE_2 = [ - np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), - np.array( + TESTS.append( [ - [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], - [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], - [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + p(np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]])), + np.array( + [ + [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], + [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], + [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + ] + ), ] - ), -] + ) class TestConvertToMultiChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) + self.assertEqual(type(result), type(data)) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data.device) + result = result.cpu().numpy() np.testing.assert_equal(result, expected_result) self.assertEqual(f"{result.dtype}", "bool") diff --git a/tests/test_convert_to_multi_channeld.py b/tests/test_convert_to_multi_channeld.py index 945e07e1cd..0fb48dde34 100644 --- a/tests/test_convert_to_multi_channeld.py +++ b/tests/test_convert_to_multi_channeld.py @@ -12,22 +12,39 @@ import unittest import numpy as np +import torch from 
parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClassesd - -TEST_CASE = [ - {"keys": "label"}, - {"label": np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])}, - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "label"}, + {"label": p(np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]))}, + np.array( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ] + ) class TestConvertToMultiChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE]) + @parameterized.expand(TESTS) def test_type_shape(self, keys, data, expected_result): - result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data) - np.testing.assert_equal(result["label"], expected_result) + result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)["label"] + self.assertEqual(type(result), type(data["label"])) + if isinstance(result, torch.Tensor): + self.assertEqual(result.device, data["label"].device) + result = result.cpu().numpy() + np.testing.assert_equal(result, expected_result) + self.assertEqual(f"{result.dtype}", "bool") if __name__ == "__main__": From f0419d0ba5a8f25cfce47f07401703225fcf7e6a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 12:52:48 +0100 Subject: [PATCH 157/176] PadCollation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/batch.py | 86 +++++++++++++----------- tests/test_pad_collation.py | 104 ++++++++++++++++++------------ 2 files changed, 109 insertions(+), 81 deletions(-) diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 17346a42de..078084a62d 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -14,16 +14,15 @@ """ from copy import deepcopy -from typing import Any, Union +from typing import Any, Sequence, Union import numpy as np import torch from monai.data.utils import list_data_collate -from monai.transforms.compose import Compose from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform -from monai.transforms.utility.array import ToTensor +from monai.transforms.transform import NumpyTransform, TorchTransform from monai.utils.enums import DataObjects, InverseKeys, Method, NumpyPadMode __all__ = [ @@ -39,11 +38,14 @@ def replace_element(to_replace, batch, idx, key_or_idx): batch[idx] = tuple(batch_idx_list) # else, replace else: - batch[idx][key_or_idx] = to_replace + if key_or_idx is not None: + batch[idx][key_or_idx] = to_replace + else: + batch[idx] = to_replace return batch -class PadListDataCollate(InvertibleTransform): +class PadListDataCollate(InvertibleTransform, TorchTransform, NumpyTransform): """ Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest tensor in each dimension. 
This transform is useful if some of the applied transforms generate batch data of @@ -75,6 +77,35 @@ def __init__( self.mode = mode self.np_kwargs = np_kwargs + def replace_batch_element(self, batch, key_or_idx, is_list_of_dicts): + # calculate max size of each dimension + max_shapes = [] + for elem in batch: + im = elem[key_or_idx] if key_or_idx is not None else elem + if not isinstance(im, (torch.Tensor, np.ndarray)): + return batch + max_shapes.append(im.shape[1:]) + max_shape = np.array(max_shapes).max(axis=0) + # If all same size, skip + if np.all(np.array(max_shapes).min(axis=0) == max_shape): + return batch + + # Use `SpatialPad` to match sizes + # Default params are central padding, padding with 0's + padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs) + + for idx, elem in enumerate(batch): + im = elem[key_or_idx] if key_or_idx is not None else elem + orig_size = im.shape[1:] + padded = padder(im) + batch = replace_element(padded, batch, idx, key_or_idx) + + # If we have a dictionary of data, append to list + if is_list_of_dicts: + self.push_transform(batch[idx], key_or_idx, orig_size=orig_size) + + return batch + def __call__(self, batch: Any): """ Args: @@ -82,40 +113,17 @@ def __call__(self, batch: Any): """ # data is either list of dicts or list of lists is_list_of_dicts = isinstance(batch[0], dict) - # loop over items inside of each element in a batch - for key_or_idx in batch[0].keys() if is_list_of_dicts else range(len(batch[0])): - # calculate max size of each dimension - max_shapes = [] - for elem in batch: - if not isinstance(elem[key_or_idx], (torch.Tensor, np.ndarray)): - break - max_shapes.append(elem[key_or_idx].shape[1:]) - # len > 0 if objects were arrays, else skip as no padding to be done - if not max_shapes: - continue - max_shape = np.array(max_shapes).max(axis=0) - # If all same size, skip - if np.all(np.array(max_shapes).min(axis=0) == max_shape): - continue - # Do we need to convert output to Tensor? 
-            output_to_tensor = isinstance(batch[0][key_or_idx], torch.Tensor)
-
-            # Use `SpatialPadd` or `SpatialPad` to match sizes
-            # Default params are central padding, padding with 0's
-            # If input is dictionary, use the dictionary version so that the transformation is recorded
-
-            padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs)
-            transform = padder if not output_to_tensor else Compose([padder, ToTensor()])
-
-            for idx, batch_i in enumerate(batch):
-                im = batch_i[key_or_idx]
-                orig_size = im.shape[1:]
-                padded = transform(batch_i[key_or_idx])
-                batch = replace_element(padded, batch, idx, key_or_idx)
-
-                # If we have a dictionary of data, append to list
-                if is_list_of_dicts:
-                    self.push_transform(batch[idx], key_or_idx, orig_size=orig_size)
+        # if data is a list of dictionaries, loop over keys
+        if is_list_of_dicts:
+            for key in batch[0].keys():
+                batch = self.replace_batch_element(batch, key, is_list_of_dicts)
+        # elif it's a list of lists/tuples
+        elif isinstance(batch[0], Sequence):
+            for idx in range(len(batch[0])):
+                batch = self.replace_batch_element(batch, idx, is_list_of_dicts)
+        # elif there's only one element per batch, either a torch.Tensor or np.ndarray
+        elif isinstance(batch[0], (torch.Tensor, np.ndarray)):
+            batch = self.replace_batch_element(batch, None, is_list_of_dicts)
 
         # After padding, use default list collator
         return list_data_collate(batch)

diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py
index a8c544558f..d60b280317 100644
--- a/tests/test_pad_collation.py
+++ b/tests/test_pad_collation.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import random
 import unittest
 from typing import List, Tuple
 
@@ -18,6 +17,7 @@
 from parameterized import parameterized
 
 from monai.data import CacheDataset, DataLoader
+from monai.data.dataset import Dataset
 from monai.data.utils import decollate_batch, pad_list_data_collate
 from monai.transforms import (
     PadListDataCollate,
@@ -30,58 +30,72 @@
     RandZoom,
     RandZoomd,
 )
-from monai.utils import set_determinism
+from tests.utils import TEST_NDARRAYS
 
 TESTS: List[Tuple] = []
 
-for pad_collate in [
-    lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1),
-    PadListDataCollate(method="end", mode="constant", constant_values=1),
-]:
-    TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
-    TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False)))
-    TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
-    TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))
-
-    TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
-    TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
-    TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
-    TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))
-
-
-class _Dataset(torch.utils.data.Dataset):
-    def __init__(self, images, labels, transforms):
-        self.images = images
-        self.labels = labels
-        self.transforms = transforms
-
-    def __len__(self):
-        return len(self.images)
-
+for p in TEST_NDARRAYS:
+    for include_label in (True, False):
+        for pad_collate in [
+            lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1),
+            PadListDataCollate(method="end", 
mode="constant", constant_values=1), + ]: + TESTS.append( + (dict, p, include_label, pad_collate, RandSpatialCropd("im", roi_size=[8, 7], random_size=True)) + ) + TESTS.append( + (dict, p, include_label, pad_collate, RandRotated("im", prob=1, range_x=np.pi, keep_size=False)) + ) + TESTS.append( + ( + dict, + p, + include_label, + pad_collate, + RandZoomd("im", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False), + ) + ) + TESTS.append((dict, p, include_label, pad_collate, RandRotate90d("im", prob=1, max_k=2))) + + TESTS.append((list, p, include_label, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True))) + TESTS.append((list, p, include_label, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False))) + TESTS.append( + (list, p, include_label, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)) + ) + TESTS.append((list, p, include_label, pad_collate, RandRotate90(prob=1, max_k=2))) + + +class TupleDataset(Dataset): def __getitem__(self, index): - return self.transforms(self.images[index]), self.labels[index] + return self.transform(self.data[index][0]), self.data[index][1] class TestPadCollation(unittest.TestCase): - def setUp(self) -> None: - set_determinism(seed=0) + @staticmethod + def get_data(t_type, im_type, include_label): # image is non square to throw rotation errors - im = np.arange(0, 10 * 9).reshape(1, 10, 9) + im = im_type(np.arange(0, 10 * 9).reshape(1, 10, 9)) num_elements = 20 - self.dict_data = [{"image": im} for _ in range(num_elements)] - self.list_data = [im for _ in range(num_elements)] - self.list_labels = [random.randint(0, 1) for _ in range(num_elements)] - - def tearDown(self) -> None: - set_determinism(None) + out = [] + for _ in range(num_elements): + label = np.random.randint(0, 1) + if t_type is dict: + out.append({"im": im, "label": label} if include_label else {"im": im}) + else: + out.append((im, label) if include_label else im) + return out @parameterized.expand(TESTS) - def test_pad_collation(self, t_type, collate_method, transform): + def test_pad_collation(self, t_type, im_type, include_label, collate_method, transform): + + input_data = self.get_data(t_type, im_type, include_label) - if t_type == dict: - dataset = CacheDataset(self.dict_data, transform, progress=False) + if t_type is dict: + dataset = CacheDataset(input_data, transform, progress=False) + elif isinstance(input_data[0], tuple): + dataset = TupleDataset(input_data, transform) else: - dataset = _Dataset(self.list_data, self.list_labels, transform) + dataset = Dataset(input_data, transform) # Default collation should raise an error loader_fail = DataLoader(dataset, batch_size=10) @@ -93,8 +107,14 @@ def test_pad_collation(self, t_type, collate_method, transform): loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) # check collation in forward direction for data in loader: - if t_type == dict: - decollated_data = decollate_batch(data) + d = data["im"] if isinstance(data, dict) else data + i = input_data[0]["im"] if isinstance(data, dict) else input_data[0] + if isinstance(i, torch.Tensor): + self.assertEqual(d.device, i.device) + + decollated_data = decollate_batch(data) + # if a dictionary, do the inverse + if t_type is dict: for d in decollated_data: PadListDataCollate.inverse(d) From ed6903ae5bdfcf2c653d5f00e1a65023afc88cd5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 14:38:09 +0100 Subject: [PATCH 158/176] mypy Signed-off-by: Richard Brown 
<33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ae3c0f81f0..8c4b73363b 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1501,9 +1501,9 @@ def _set_spike(self, k: DataObjects.Images, idx: Tuple, val: Union[Sequence[floa else: k[idx] = val elif len(k.shape) == 4 and len(idx) == 3: - k[:, idx[0], idx[1], idx[2]] = val + k[:, idx[0], idx[1], idx[2]] = val # type: ignore elif len(k.shape) == 3 and len(idx) == 2: - k[:, idx[0], idx[1]] = val + k[:, idx[0], idx[1]] = val # type: ignore def _shift_fourier(self, x: DataObjects.Images, n_dims: int, data_type: type) -> DataObjects.Images: """ From 9a7cf963c75c08653a4a0c724a3ec2c16fcdc791 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 14:48:50 +0100 Subject: [PATCH 159/176] fix 1d SpatialCrop Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index cfe6125089..a06e997775 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -369,8 +369,11 @@ def __init__( roi_start = torch.as_tensor(roi_start, dtype=torch.int16) roi_start = torch.maximum(roi_start, torch.tensor(0, device=roi_start.device)) roi_end = torch.maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start) - # convert to slices - self.slices = [slice(s, e) for s, e in zip(roi_start, roi_end)] + # convert to slices (accounting for 1d) + if roi_start.numel() == 1: + self.slices = [slice(roi_start, roi_end)] + else: + self.slices = [slice(s, e) for s, e in zip(roi_start, roi_end)] def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ From 9db32569b42fcbc797e280253ea8ceec9972c312 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 15:29:25 +0100 Subject: [PATCH 160/176] fix randbiasfield Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 10 ++++++---- tests/test_random_bias_field.py | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 8c4b73363b..cc0e8b214e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -498,7 +498,6 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: if not self._do_transform: return img num_channels, *spatial_shape = img.shape - _bias_fields: DataObjects.Images _bias_fields = np.stack( [ self._generate_random_field(spatial_shape=spatial_shape, degree=self.degree, coeff=self._coeff) @@ -506,10 +505,13 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: ], axis=0, ) - _bias_fields, *_ = convert_data_type( - _bias_fields, type(img), dtype=self.dtype, device=img.device if isinstance(img, torch.Tensor) else None + + _bias_fields_exp: DataObjects.Images + _bias_fields_exp = np.exp(_bias_fields) + _bias_fields_exp, *_ = convert_data_type( + _bias_fields_exp, type(img), dtype=self.dtype, device=img.device if isinstance(img, torch.Tensor) else None ) - out = img * _bias_fields + out = img * _bias_fields_exp out, *_ = convert_data_type(out, 
dtype=self.dtype) return out diff --git a/tests/test_random_bias_field.py b/tests/test_random_bias_field.py index 995db1e6e6..8df21a3a40 100644 --- a/tests/test_random_bias_field.py +++ b/tests/test_random_bias_field.py @@ -49,14 +49,14 @@ def test_output_shape(self, class_args, img_shape): def test_zero_range(self, class_args, img_shape): for p in TEST_NDARRAYS: bias_field = RandBiasField(**class_args) - img = p(np.random.rand(*img_shape)) + img = p(np.ones(img_shape)) output = bias_field(img) self.assertEqual(type(img), type(output)) if isinstance(output, torch.Tensor): self.assertEqual(output.device, img.device) output = output.cpu().numpy() - np.testing.assert_equal(output, np.zeros(img_shape)) + np.testing.assert_allclose(output, np.ones(img_shape), rtol=1e-3) @parameterized.expand([TEST_CASES_2D_ONES]) def test_one_range_input(self, class_args, expected): @@ -68,7 +68,7 @@ def test_one_range_input(self, class_args, expected): if isinstance(output, torch.Tensor): self.assertEqual(output.device, img.device) output = output.cpu().numpy() - np.testing.assert_equal(output, expected.astype(bias_field.dtype)) + np.testing.assert_allclose(output, expected.astype(bias_field.dtype), rtol=1e-3) def test_zero_prob(self): for p in TEST_NDARRAYS: From ecbb200f72370b45bc7d0baf9e868e1d6c09ef49 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 15:52:40 +0100 Subject: [PATCH 161/176] fix KSpaceSpikeNoise Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/intensity/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index cc0e8b214e..0d53cbb8c5 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1447,6 +1447,7 @@ def __call__(self, img: DataObjects.Images) -> DataObjects.Images: n_dims = len(img.shape[1:]) data_type = type(img) + lib = np if isinstance(img, np.ndarray) else torch # FT k = self._shift_fourier(img, n_dims, data_type) From c3323d5e6a7b0e2f54ade49cdd7c67046c2c906c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 15:54:51 +0100 Subject: [PATCH 162/176] fix RandBiasFieldd Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_random_bias_fieldd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_random_bias_fieldd.py b/tests/test_random_bias_fieldd.py index 8df3242442..f0df49231f 100644 --- a/tests/test_random_bias_fieldd.py +++ b/tests/test_random_bias_fieldd.py @@ -54,7 +54,7 @@ def test_zero_range(self, class_args, img_shape): self.assertEqual(output[key].device, img.device) output = output[key].cpu().numpy() - np.testing.assert_equal(output[key], np.zeros(img_shape)) + np.testing.assert_allclose(output[key], np.ones(img_shape), rtol=1e-3) @parameterized.expand([TEST_CASES_2D_ONES]) def test_one_range_input(self, class_args, expected): From e645a3f16d87dd29356600b1c0fc7d7ec5db9c8e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 28 Jul 2021 16:36:50 +0100 Subject: [PATCH 163/176] fix generate_pos_neg_label_crop_centers Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 9 +++-- ...est_generate_pos_neg_label_crop_centers.py | 37 +++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/monai/transforms/utils.py 
b/monai/transforms/utils.py index 5365c3f8a7..f7e09de2b5 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -481,10 +481,12 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - if fg_indices.size == 0 and bg_indices.size == 0: + fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices + bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices + if len(fg_indices) == 0 and len(bg_indices) == 0: raise ValueError("No sampling location available.") - if fg_indices.size == 0 or bg_indices.size == 0: + if len(fg_indices) == 0 or len(bg_indices) == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," "unable to generate class balanced samples." @@ -495,8 +497,7 @@ def generate_pos_neg_label_crop_centers( indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) idx = indices_to_use[random_int] - idx = idx.cpu() if isinstance(idx, torch.Tensor) else idx - center = np.unravel_index(idx, label_spatial_shape) + center = _unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 40181aa9ea..0335a94d97 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -15,25 +15,30 @@ from parameterized import parameterized from monai.transforms import generate_pos_neg_label_crop_centers - -TEST_CASE_1 = [ - { - "spatial_size": [2, 2, 2], - "num_samples": 2, - "pos_ratio": 1.0, - "label_spatial_shape": [3, 3, 3], - "fg_indices": [1, 9, 18], - "bg_indices": [3, 12, 21], - "rand_state": np.random.RandomState(), - }, - list, - 2, - 3, -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS + (None,): + TESTS.append( + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 1.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [1, 9, 18] if p is None else p([1, 9, 18]), + "bg_indices": [3, 12, 21] if p is None else p([3, 12, 21]), + "rand_state": np.random.RandomState(), + }, + list, + 2, + 3, + ] + ) class TestGeneratePosNegLabelCropCenters(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): result = generate_pos_neg_label_crop_centers(**input_data) self.assertIsInstance(result, expected_type) From 293f1962c3b7421a7cdbf5912c9083c8c28b9739 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 12:16:27 +0100 Subject: [PATCH 164/176] fix generate_spatial_bounding_box Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index f7e09de2b5..d20b0e29c5 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -739,10 +739,20 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data if len(ax) == 0 else data.any(ax[0]) + dt = data + if len(ax) != 0: + if 
isinstance(dt, np.ndarray): + dt = dt.any(ax) + # pytorch can't handle multiple dimensions to `any` so loop across them + # this works because the dimensions will be reverse sorted. + else: + for i in ax: + dt = dt.any(i) + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim + dt = dt if isinstance(dt, np.ndarray) else dt.int() rev_dt = dt[::-1] if isinstance(dt, np.ndarray) else dt.flip(0) From d03911b9e5ad60e749c11be3780fcbd0d4cfc9b9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 11:27:46 +0000 Subject: [PATCH 165/176] fix resize Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 33 +++++++++++++++++++------------ tests/test_resized.py | 2 +- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 79c95f8e11..9eda4ced83 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,8 +14,8 @@ """ import warnings -from math import ceil from copy import deepcopy +from math import ceil from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np @@ -404,20 +404,27 @@ def __call__( """ img_t: torch.Tensor img_t, orig_type, orig_device = convert_data_type(img, torch.Tensor, dtype=float) # type: ignore - input_ndim = img_t.ndim - 1 # spatial ndim - output_ndim = len(self.spatial_size) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(img_t.shape, output_ndim + 1, 1) - img_t = img_t.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - spatial_size = fall_back_tuple(self.spatial_size, img_t.shape[1:]) + if self.size_mode == "all": + input_ndim = img_t.ndim - 1 # spatial ndim + output_ndim = len(ensure_tuple(self.spatial_size)) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(img_t.shape, output_ndim + 1, 1) + img_t = img_t.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." 
+ ) + spatial_size_ = fall_back_tuple(self.spatial_size, img_t.shape[1:]) + else: + img_size = img.shape[1:] + if not isinstance(self.spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + scale = self.spatial_size / max(img_size) + spatial_size_ = tuple(ceil(s * scale) for s in img_size) resized = torch.nn.functional.interpolate( input=img_t.unsqueeze(0), - size=spatial_size, + size=spatial_size_, mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, align_corners=self.align_corners if align_corners is None else align_corners, ) diff --git a/tests/test_resized.py b/tests/test_resized.py index a88715357a..f9b605c919 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -59,7 +59,7 @@ def test_correct_results(self, in_type, spatial_size, mode): out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_longest_shape(self, input_param, expected_shape): input_data = { "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), From 40fa2889ee6abb17af7d53b6c4e84edef2b5d508 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 12:20:08 +0000 Subject: [PATCH 166/176] separate one_hot fns for np and torch (allows for torchscript) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/losses/dice.py | 6 +-- monai/losses/focal_loss.py | 4 +- monai/losses/tversky.py | 4 +- monai/networks/__init__.py | 2 + monai/networks/utils.py | 67 +++++++++++++++++++----------- tests/test_focal_loss.py | 8 ++-- tests/test_seg_loss_integration.py | 4 +- 7 files changed, 57 insertions(+), 38 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index e8e6039bc7..cff107acc2 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -20,7 +20,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction, Weight, look_up_option from monai.utils.enums import DataObjects @@ -130,7 +130,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) # type: ignore + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: @@ -306,7 +306,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) # type: ignore + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index f4c361168e..7f11ecaca0 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from torch.nn.modules.loss import _Loss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction @@ -96,7 +96,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") 
else: - target = one_hot(target, num_classes=n_pred_ch) # type: ignore + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 8170920b47..355a662901 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -15,7 +15,7 @@ import torch from torch.nn.modules.loss import _Loss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from monai.utils import LossReduction @@ -120,7 +120,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - target = one_hot(target, num_classes=n_pred_ch) # type: ignore + target = one_hot_torch(target, num_classes=n_pred_ch) if not self.include_background: if n_pred_ch == 1: diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 3c347dad22..b74805ef88 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -16,6 +16,8 @@ normal_init, normalize_transform, one_hot, + one_hot_np, + one_hot_torch, pixelshuffle, predict_segmentation, slice_channels, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index dac2f6039d..62bbc7bbe0 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -27,6 +27,8 @@ __all__ = [ "one_hot", + "one_hot_np", + "one_hot_torch", "slice_channels", "predict_segmentation", "normalize_transform", @@ -40,6 +42,44 @@ ] +def one_hot_np(labels: np.ndarray, num_classes: int, dtype: DtypeLike = np.float32, dim: int = 1) -> np.ndarray: + """ + Numpy implementation of `one_hot`. + See also: :py:meth:`monai.networks.utils.one_hot`. + """ + label: np.ndarray + label = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end + label = label.astype(dtype) + label = label.squeeze(dim) # remove singleton + label = np.moveaxis(label, -1, dim) # move one hot dim to desired index + return label + + +def one_hot_torch( + labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1 +) -> torch.Tensor: + """ + Torch implementation of `one_hot`. + See also: :py:meth:`monai.networks.utils.one_hot`. 
+ """ + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = torch.reshape(labels, shape) + + sh = list(labels.shape) + + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + + sh[dim] = num_classes + + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) + + return labels + + def one_hot( labels: DataObjects.Images, num_classes: int, dtype: Union[DtypeLike, torch.dtype] = torch.float, dim: int = 1 ) -> DataObjects.Images: @@ -74,34 +114,11 @@ def one_hot( a = torch.randint(0, 2, size=(2, 1, 2, 2, 2)) out = one_hot(a, num_classes=2, dim=1) print(out.shape) # torch.Size([2, 2, 2, 2, 2]) - """ dtype = dtype_convert(dtype, type(labels)) - - # if `dim` is bigger, add singleton dim at the end - if labels.ndim < dim + 1: - shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) - labels = labels.reshape(shape) - - sh = list(labels.shape) - - if sh[dim] != 1: - raise AssertionError("labels should have a channel with length equal to one.") - if isinstance(labels, np.ndarray): - label_np: np.ndarray - label_np = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end - label_np = label_np.astype(dtype) # type: ignore - label_np = label_np.squeeze(dim) # remove singleton - label_np = np.moveaxis(label_np, -1, dim) # move one hot dim to desired index - return label_np - - sh[dim] = num_classes - - o = torch.zeros(size=sh, dtype=dtype, device=labels.device) # type: ignore - labels = o.scatter_(dim=dim, index=labels.long(), value=1) - - return labels + return one_hot_np(labels, num_classes, dtype, dim) # type: ignore + return one_hot_torch(labels, num_classes, dtype, dim) # type: ignore def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Tensor: diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1314fe3841..5c102989f3 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from monai.losses import FocalLoss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch from tests.utils import SkipIfBeforePyTorchVersion, test_script_save @@ -60,7 +60,7 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -84,7 +84,7 @@ def test_consistency_with_cross_entropy_classification(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -108,7 +108,7 @@ def test_consistency_with_cross_entropy_classification_01(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, one_hot(l, num_classes=class_num)) + output1 = ce(x, one_hot_torch(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: diff --git a/tests/test_seg_loss_integration.py b/tests/test_seg_loss_integration.py index d2f991f160..807597cdc1 100644 --- a/tests/test_seg_loss_integration.py +++ 
b/tests/test_seg_loss_integration.py @@ -18,7 +18,7 @@ from parameterized import parameterized from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, TverskyLoss -from monai.networks import one_hot +from monai.networks.utils import one_hot_torch TEST_CASES = [ [DiceLoss, {"to_onehot_y": True, "squared_pred": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, {}], @@ -86,7 +86,7 @@ def test_convergence(self, loss_type, loss_args, forward_args): num_classes = 2 num_voxels = 3 * 4 * 4 - target_onehot = one_hot(target_seg, num_classes=num_classes) + target_onehot = one_hot_torch(target_seg, num_classes=num_classes) # define a one layer model class OnelayerNet(nn.Module): From fd1af06507c60d57efec4d45af6cf5c14ede48a9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 12:48:15 +0000 Subject: [PATCH 167/176] fix various tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_invertd.py | 1 - tests/test_resize.py | 15 +++++++-------- tests/test_resized.py | 27 ++++++++++++++++++++++----- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 820ef2e087..5a74e90c99 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -84,7 +84,6 @@ def test_invert(self): nearest_interp=True, to_tensor=[True, False, False], device="cpu", - num_workers=num_workers, ) # execute 1 epoch diff --git a/tests/test_resize.py b/tests/test_resize.py index eaa84e0592..7b1488b0e8 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -21,17 +21,16 @@ from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D TESTS: List[Tuple] = [] +TEST_LONGEST: List[Tuple] = [] for p in TEST_NDARRAYS: TESTS.append((p, (32, -1), "area")) TESTS.append((p, (32, 32), "area")) TESTS.append((p, (32, 32, 32), "trilinear")) TESTS.append((p, (256, 256), "bilinear")) -TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)] - -TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)] - -TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)] + TEST_LONGEST.append((p, {"spatial_size": 15}, (6, 11, 15))) + TEST_LONGEST.append((p, {"spatial_size": 15, "mode": "area"}, (6, 11, 15))) + TEST_LONGEST.append((p, {"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6))) class TestResize(NumpyImageTestCase2D): @@ -68,9 +67,9 @@ def test_correct_results(self, in_type, spatial_size, mode): out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) - def test_longest_shape(self, input_param, expected_shape): - input_data = np.random.randint(0, 2, size=[3, 4, 7, 10]) + @parameterized.expand(TEST_LONGEST) + def test_longest_shape(self, im_type, input_param, expected_shape): + input_data = im_type(np.random.randint(0, 2, size=[3, 4, 7, 10])) input_param["size_mode"] = "longest" result = Resize(**input_param)(input_data) np.testing.assert_allclose(result.shape[1:], expected_shape) diff --git a/tests/test_resized.py b/tests/test_resized.py index f9b605c919..c0bd9a49a7 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -21,12 +21,29 @@ from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D TESTS: List[Tuple] = [] +TEST_LONGEST: List[Tuple] = [] for p in TEST_NDARRAYS: TESTS.append((p, (32, -1), "area")) TESTS.append((p, (64, 64), "area")) TESTS.append((p, (32, 32, 32), "area")) TESTS.append((p, (256, 256), "bilinear")) + 
TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 15}, (6, 11, 15))) + TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15))) + TEST_LONGEST.append((p, {"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6))) + TEST_LONGEST.append( + ( + p, + { + "keys": ["img", "label"], + "spatial_size": 6, + "mode": ["trilinear", "nearest"], + "align_corners": [True, None], + }, + (3, 5, 6), + ) + ) + class TestResized(NumpyImageTestCase2D): def test_invalid_inputs(self): @@ -40,7 +57,7 @@ def test_invalid_inputs(self): @parameterized.expand(TESTS) def test_correct_results(self, in_type, spatial_size, mode): - resize = Resized("img", spatial_size, mode) + resize = Resized("img", spatial_size, mode=mode) _order = 0 if mode.endswith("linear"): _order = 1 @@ -59,11 +76,11 @@ def test_correct_results(self, in_type, spatial_size, mode): out = out.cpu() np.testing.assert_allclose(out, expected, atol=0.9) - @parameterized.expand(TESTS) - def test_longest_shape(self, input_param, expected_shape): + @parameterized.expand(TEST_LONGEST) + def test_longest_shape(self, in_type, input_param, expected_shape): input_data = { - "img": np.random.randint(0, 2, size=[3, 4, 7, 10]), - "label": np.random.randint(0, 2, size=[3, 4, 7, 10]), + "img": in_type(np.random.randint(0, 2, size=[3, 4, 7, 10])), + "label": in_type(np.random.randint(0, 2, size=[3, 4, 7, 10])), } input_param["size_mode"] = "longest" rescaler = Resized(**input_param) From be9ce910aa689957fba0dc775919d5c448853400 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 13:06:23 +0000 Subject: [PATCH 168/176] fix one_hot Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 62bbc7bbe0..cdfc60b276 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -42,11 +42,26 @@ ] +def _one_hot_pre_process(labels, dim): + # if `dim` is bigger, add singleton dim at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = labels.reshape(shape) + + sh = list(labels.shape) + + if sh[dim] != 1: + raise AssertionError("labels should have a channel with length equal to one.") + return labels, sh + + def one_hot_np(labels: np.ndarray, num_classes: int, dtype: DtypeLike = np.float32, dim: int = 1) -> np.ndarray: """ Numpy implementation of `one_hot`. See also: :py:meth:`monai.networks.utils.one_hot`. """ + labels, _ = _one_hot_pre_process(labels, dim) + label: np.ndarray label = np.eye(num_classes)[labels.astype(np.longlong)] # adds one hot to end label = label.astype(dtype) @@ -62,15 +77,7 @@ def one_hot_torch( Torch implementation of `one_hot`. See also: :py:meth:`monai.networks.utils.one_hot`. 
""" - # if `dim` is bigger, add singleton dim at the end - if labels.ndim < dim + 1: - shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) - labels = torch.reshape(labels, shape) - - sh = list(labels.shape) - - if sh[dim] != 1: - raise AssertionError("labels should have a channel with length equal to one.") + labels, sh = _one_hot_pre_process(labels, dim) sh[dim] = num_classes From 6401533c595acaa90a5839a634ff2ce93cb9593d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 13:21:22 +0000 Subject: [PATCH 169/176] fix one_hot 2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index cdfc60b276..cd1485fd1d 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -42,7 +42,7 @@ ] -def _one_hot_pre_process(labels, dim): +def _one_hot_pre_process(labels, dim: int): # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) From c30398987d61998c7eaa17bc25bad4d43c223d27 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Jul 2021 15:10:32 +0100 Subject: [PATCH 170/176] fix test_invers Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a06e997775..9c5a8c7fed 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -359,7 +359,7 @@ def __init__( roi_center = torch.as_tensor(roi_center, dtype=torch.int16) roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device) roi_start = torch.maximum( - roi_center - torch.div(roi_size, 2, rounding_mode="floor"), + roi_center - torch.floor(roi_size / 2), torch.tensor(0, device=roi_center.device), ) roi_end = torch.maximum(roi_start + roi_size, roi_start) @@ -371,9 +371,9 @@ def __init__( roi_end = torch.maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start) # convert to slices (accounting for 1d) if roi_start.numel() == 1: - self.slices = [slice(roi_start, roi_end)] + self.slices = [slice(int(roi_start.item()), int(roi_end.item()))] else: - self.slices = [slice(s, e) for s, e in zip(roi_start, roi_end)] + self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start, roi_end)] def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ From 34776d819778533602fe5d8fda5e16a901552a0a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 30 Jul 2021 10:23:35 +0100 Subject: [PATCH 171/176] fix pt1.6 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 35 ++++++++++++++++++++++--------- monai/transforms/utils.py | 6 +++++- monai/utils/__init__.py | 1 + monai/utils/misc.py | 8 ++++++- tests/utils.py | 5 ++--- 5 files changed, 40 insertions(+), 15 deletions(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 9c5a8c7fed..da1278be15 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -36,7 +36,7 @@ ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option from 
monai.utils.enums import DataObjects -from monai.utils.misc import convert_data_type +from monai.utils.misc import convert_data_type, is_module_ver_at_least __all__ = [ "Pad", @@ -332,6 +332,21 @@ class SpatialCrop(TorchTransform, NumpyTransform): - the start and end coordinates of the ROI """ + @staticmethod + def _maximum(a, b): + if isinstance(a, np.ndarray): + return np.maximum(a, b) + # is torch and has torch.maximum (pt>1.6) + if hasattr(torch, "maximum"): + return torch.maximum(a, b) + return np.maximum(a.cpu(), b.cpu()) + + @staticmethod + def _floor_div(a, b): + if is_module_ver_at_least(torch, (1, 7, 0)): + return torch.div(a, b, rounding_mode="floor") + return torch.floor_divide(a, b) + def __init__( self, roi_center: Optional[Union[Sequence[int], DataObjects.Images]] = None, @@ -358,22 +373,22 @@ def __init__( if roi_center is not None and roi_size is not None: roi_center = torch.as_tensor(roi_center, dtype=torch.int16) roi_size = torch.as_tensor(roi_size, dtype=torch.int16, device=roi_center.device) - roi_start = torch.maximum( - roi_center - torch.floor(roi_size / 2), + roi_start_torch = self._maximum( + roi_center - self._floor_div(roi_size, 2), torch.tensor(0, device=roi_center.device), ) - roi_end = torch.maximum(roi_start + roi_size, roi_start) + roi_end_torch = self._maximum(roi_start_torch + roi_size, roi_start_torch) else: if roi_start is None or roi_end is None: raise ValueError("Please specify either roi_center, roi_size or roi_start, roi_end.") - roi_start = torch.as_tensor(roi_start, dtype=torch.int16) - roi_start = torch.maximum(roi_start, torch.tensor(0, device=roi_start.device)) - roi_end = torch.maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start) + roi_start_torch = torch.as_tensor(roi_start, dtype=torch.int16) + roi_start_torch = self._maximum(roi_start_torch, torch.tensor(0, device=roi_start_torch.device)) + roi_end_torch = self._maximum(torch.as_tensor(roi_end, dtype=torch.int16), roi_start_torch) # convert to slices (accounting for 1d) - if roi_start.numel() == 1: - self.slices = [slice(int(roi_start.item()), int(roi_end.item()))] + if roi_start_torch.numel() == 1: + self.slices = [slice(int(roi_start_torch.item()), int(roi_end_torch.item()))] else: - self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start, roi_end)] + self.slices = [slice(int(s.item()), int(e.item())) for s, e in zip(roi_start_torch, roi_end_torch)] def __call__(self, img: DataObjects.Images) -> DataObjects.Images: """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d20b0e29c5..f66181a3cb 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -36,7 +36,7 @@ optional_import, ) from monai.utils.enums import DataObjects -from monai.utils.misc import convert_data_type +from monai.utils.misc import convert_data_type, is_module_ver_at_least measure, _ = optional_import("skimage.measure", "0.14.2", min_version) cp, has_cp = optional_import("cupy") @@ -727,6 +727,10 @@ def generate_spatial_bounding_box( of image. if None, select foreground on the whole image. margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. 
""" + # argmax changed for more recent pytorch versions + if isinstance(img, torch.Tensor) and not is_module_ver_at_least(torch, (1, 7, 0)): + img = img.cpu().numpy() + data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img data = select_fn(data).any(0) ndim = len(data.shape) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 1c52e8dce6..a1be196874 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -50,6 +50,7 @@ first, get_seed, has_option, + is_module_ver_at_least, is_scalar, is_scalar_tensor, issequenceiterable, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 5dcf71a0eb..c8041e653c 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -24,7 +24,7 @@ from monai.config.type_definitions import DtypeLike from monai.utils.enums import DataObjects -from monai.utils.module import get_torch_version_tuple +from monai.utils.module import get_torch_version_tuple, version_leq __all__ = [ "zip_with", @@ -48,6 +48,7 @@ "copy_to_device", "ImageMetaKey", "convert_data_type", + "is_module_ver_at_least", ] _seed = None @@ -464,3 +465,8 @@ def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: return False sig = inspect.signature(obj) return all(key in sig.parameters for key in ensure_tuple(keywords)) + + +def is_module_ver_at_least(module, required_version_tuple): + test_ver = ".".join(map(str, required_version_tuple)) + return module.__version__ != test_ver and version_leq(test_ver, module.__version__) diff --git a/tests/utils.py b/tests/utils.py index 3517a5be2a..8a3a4d3c23 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,7 @@ from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism +from monai.utils.misc import is_module_ver_at_least from monai.utils.module import version_leq nib, _ = optional_import("nibabel") @@ -112,9 +113,7 @@ class SkipIfBeforePyTorchVersion: with PyTorch versions older than that given.""" def __init__(self, pytorch_version_tuple): - self.min_version = pytorch_version_tuple - test_ver = ".".join(map(str, self.min_version)) - self.version_too_old = torch.__version__ != test_ver and version_leq(torch.__version__, test_ver) + self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) def __call__(self, obj): return unittest.skipIf( From 0ead56674fa0b03faca65866011e159f0e0a6b73 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 30 Jul 2021 10:37:55 +0100 Subject: [PATCH 172/176] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils.py b/tests/utils.py index 8a3a4d3c23..988d40488a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -113,6 +113,7 @@ class SkipIfBeforePyTorchVersion: with PyTorch versions older than that given.""" def __init__(self, pytorch_version_tuple): + self.min_version = pytorch_version_tuple self.version_too_old = not is_module_ver_at_least(torch, pytorch_version_tuple) def __call__(self, obj): From 446d7ea87b24c606f8f1017d4afa44e98056ad5d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 30 Jul 2021 14:34:17 +0100 Subject: [PATCH 173/176] pt1.7 fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/array.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index da1278be15..8828dbd102 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -343,7 +343,7 @@ def _maximum(a, b): @staticmethod def _floor_div(a, b): - if is_module_ver_at_least(torch, (1, 7, 0)): + if is_module_ver_at_least(torch, (1, 8, 0)): return torch.div(a, b, rounding_mode="floor") return torch.floor_divide(a, b) From cb1f04722359c80264dbe5ce023a019de6fd195c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 30 Jul 2021 18:19:11 +0100 Subject: [PATCH 174/176] fix torch1.7 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9957ef62c3..77ff5fe70b 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -35,7 +35,7 @@ ) from monai.utils import ensure_tuple, min_version, optional_import from monai.utils.enums import DataObjects -from monai.utils.misc import convert_data_type, dtype_convert +from monai.utils.misc import convert_data_type, dtype_convert, is_module_ver_at_least PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -683,7 +683,6 @@ def __call__( merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes, will return a single channel mask with binary data. """ - if select_labels is None: select_labels = self.select_labels else: @@ -693,13 +692,16 @@ def __call__( data = img[[*select_labels]] else: where = np.where if isinstance(img, np.ndarray) else torch.where - data = where(self._in1d(img, select_labels), True, False).reshape(img.shape) + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + data = where(self._in1d(img, select_labels), True, False).reshape(img.shape) + else: + data = where(self._in1d(img, select_labels), 1, 0).reshape(img.shape) if merge_channels or self.merge_channels: - if isinstance(data, np.ndarray): - return np.any(data, axis=0, keepdims=True) # type: ignore + if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): + return data.any(0)[None] else: - return torch.any(data, dim=0, keepdim=True) + return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore return data From f30f72e9f044f272a4238bebb27ea051bd8048f6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 30 Jul 2021 18:24:32 +0100 Subject: [PATCH 175/176] post merge fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/array.py | 5 ++--- monai/transforms/utility/dictionary.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 998a589a52..eeb04997af 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -22,9 +22,8 @@ import numpy as np import torch -from monai.config import DtypeLike, NdarrayTensor -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform -from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform +from monai.config import DtypeLike +from monai.transforms.transform import NumpyTransform, Randomizable, 
RandomizableTransform, TorchTransform, Transform from monai.transforms.utils import ( _unravel_index, convert_to_numpy, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 5382939270..9d3214d790 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -23,7 +23,7 @@ import numpy as np import torch -from monai.config import DtypeLike, KeysCollection, NdarrayTensor +from monai.config import DtypeLike, KeysCollection from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform From 228e3b401cf48196aa01d42348007177b6e38e68 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 10 Aug 2021 11:06:49 +0100 Subject: [PATCH 176/176] get number of conversions Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/compose.py | 23 ++++++++++++++++++++++ tests/test_compose.py | 38 ++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b380f7d42a..a488f289de 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -12,6 +12,7 @@ A collection of generic interfaces for MONAI transforms. """ +import math import warnings from typing import Any, Callable, Optional, Sequence, Union @@ -22,8 +23,10 @@ # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import ( # noqa: F401 MapTransform, + NumpyTransform, Randomizable, RandomizableTransform, + TorchTransform, Transform, apply_transform, ) @@ -168,3 +171,23 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data + + def get_number_np_torch_conversions(self): + """Get the number of times that the data need to be converted from numpy to torch or vice versa. 
+        Returns `math.nan` if any of the transforms is neither a `NumpyTransform` nor a `TorchTransform`."""
+        num_conversions = 0
+        tr = self.flatten().transforms
+        # if any are unknown, return math.nan
+        if not all(isinstance(t, (NumpyTransform, TorchTransform)) for t in tr):
+            return math.nan
+        # transforms deriving from both base classes are type-agnostic, so ignore them
+        tr = [t for t in tr if not (isinstance(t, NumpyTransform) and isinstance(t, TorchTransform))]
+        if len(tr) > 0:
+            t_prev = tr[0]
+            for t in tr[1:]:
+                if isinstance(t, NumpyTransform) and isinstance(t_prev, TorchTransform):
+                    num_conversions += 1
+                elif isinstance(t, TorchTransform) and isinstance(t_prev, NumpyTransform):
+                    num_conversions += 1
+                t_prev = t
+        return num_conversions
diff --git a/tests/test_compose.py b/tests/test_compose.py
index 28783cad23..a3c5f655e6 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -12,11 +12,44 @@
 import sys
 import unittest
 
+import numpy as np
+
 from monai.data import DataLoader, Dataset
 from monai.transforms import AddChannel, Compose
-from monai.transforms.transform import Randomizable
+from monai.transforms.transform import NumpyTransform, Randomizable, TorchTransform, Transform
 from monai.utils import set_determinism
+from parameterized import parameterized
+import math
+
+
+class Tr(Transform):
+    def __call__(self, x):
+        return x
+
+
+class N(Tr, NumpyTransform):
+    pass
+
+
+class T(Tr, TorchTransform):
+    pass
+
+
+class NT(Tr, NumpyTransform, TorchTransform):
+    pass
+
+
+TEST_NUMBER_CONVERSIONS = [
+    ((N(), N()), 0),
+    ((T(), T()), 0),
+    ((NT(), NT()), 0),
+    ((N(), T()), 1),
+    ((N(), T(), N()), 2),
+    ((N(), NT(), T(), T(), NT(), NT(), N()), 2),
+    ((N(), NT(), T(), Compose([T(), NT(), NT(), N()])), 2),
+    ((Tr(), N(), T()), math.nan),
+]
 
 
 class _RandXform(Randomizable):
     def randomize(self):
@@ -217,6 +250,15 @@ def test_flatten_and_len(self):
     def test_backwards_compatible_imports(self):
         from monai.transforms.compose import MapTransform, RandomizableTransform, Transform  # noqa: F401
 
+    @parameterized.expand(TEST_NUMBER_CONVERSIONS)
+    def test_get_number_of_conversions(self, transforms, expected):
+        tr = Compose(transforms)
+        n = tr.get_number_np_torch_conversions()
+        if math.isnan(expected):
+            self.assertTrue(math.isnan(n))
+        else:
+            self.assertEqual(n, expected)
+
 
 if __name__ == "__main__":
     unittest.main()
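A few usage sketches follow to illustrate the APIs touched by this series; they are not part of the patches above. First, the reworked ConvertToMultiChannelBasedOnBratsClasses accepts either array type and returns the matching type, as its updated tests assert. A minimal sketch (the label values follow the brats18 convention quoted in the docstring; the 3x3 label map is made up for illustration):

    import numpy as np
    import torch
    from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses

    convert = ConvertToMultiChannelBasedOnBratsClasses()
    label = np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])
    out_np = convert(label)                  # np.ndarray, dtype bool, shape (3, 3, 3): TC, WT, ET channels
    out_t = convert(torch.as_tensor(label))  # torch.Tensor, same values, kept on the input device
    print(out_np.shape, out_np.dtype)        # expected: (3, 3, 3) bool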
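PATCH 157's per-key dispatch means PadListDataCollate now handles lists of dicts, lists of tuples, and bare arrays alike. A sketch of the list-of-dicts path (the key name "im" and the shapes are chosen arbitrarily; the output shape assumes the default collate stacks the padded arrays into one tensor):

    import numpy as np
    from monai.transforms import PadListDataCollate

    # two channel-first images whose spatial sizes differ
    batch = [{"im": np.zeros((1, 10, 9))}, {"im": np.zeros((1, 8, 12))}]
    collated = PadListDataCollate(method="end", mode="constant")(batch)
    print(collated["im"].shape)  # expected: torch.Size([2, 1, 10, 12]), each image padded to the per-dim maximum

Because the transform is invertible, the orig_size recorded by push_transform lets PadListDataCollate.inverse strip the padding again after decollation, as exercised in test_pad_collation.py.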
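PATCHES 166 and 168 split one_hot into backend-specific one_hot_np and one_hot_torch with shared pre-processing, keeping the torch path torchscript-friendly. A sketch of the expected equivalence (per the one_hot docstring, the dimension being expanded must have length one):

    import torch
    from monai.networks.utils import one_hot_np, one_hot_torch

    labels = torch.randint(0, 3, size=(2, 1, 4, 4))    # singleton channel dim at dim=1
    out_t = one_hot_torch(labels, num_classes=3)       # expected: torch.Size([2, 3, 4, 4])
    out_n = one_hot_np(labels.numpy(), num_classes=3)  # expected: (2, 3, 4, 4)
    assert (out_t.numpy() == out_n).all()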
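Finally, the counter added in PATCH 176 makes it possible to audit a pipeline for implicit numpy/torch round trips before running it. A minimal sketch mirroring the new test cases (NpOnly and TorchOnly are hypothetical stand-ins for real single-backend transforms):

    from monai.transforms import Compose
    from monai.transforms.transform import NumpyTransform, TorchTransform

    class NpOnly(NumpyTransform):      # hypothetical numpy-only transform
        def __call__(self, x):
            return x

    class TorchOnly(TorchTransform):   # hypothetical torch-only transform
        def __call__(self, x):
            return x

    # numpy -> torch -> numpy forces two conversions
    print(Compose([NpOnly(), TorchOnly(), NpOnly()]).get_number_np_torch_conversions())  # expected: 2

A transform deriving from both base classes is treated as type-agnostic and never contributes to the count, which is why the NT cases in test_compose.py report zero.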