diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 6aa45a9f1d..bf6a19608a 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -941,7 +941,7 @@ class MaskIntensity(Transform): """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, mask_data: Optional[NdarrayOrTensor] = None, select_fn: Callable = is_positive) -> None: self.mask_data = mask_data diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3366f17653..e30ef3241d 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -38,14 +38,13 @@ convert_to_numpy, convert_to_tensor, ensure_tuple, - get_equivalent_dtype, look_up_option, min_version, optional_import, ) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_to_dst_type +from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -389,11 +388,9 @@ def __call__(self, data: NdarrayOrTensor): if applicable. """ - if self.data_type == "tensor": - dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor) - return convert_to_tensor(data, dtype=dtype_, device=self.device) - dtype_ = get_equivalent_dtype(self.dtype, np.ndarray) - return convert_to_numpy(data, dtype=dtype_) + output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray + out, *_ = convert_data_type(data, output_type=output_type, dtype=self.dtype, device=self.device) + return out class ToNumpy(Transform): @@ -880,18 +877,19 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform): and ET (Enhancing tumor). """ - def __call__(self, img: np.ndarray) -> np.ndarray: - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: # if img has channel dim, squeeze it if img.ndim == 4 and img.shape[0] == 1: - img = np.squeeze(img, axis=0) + img = img.squeeze(0) - result = [np.logical_or(img == 1, img == 4)] + result = [(img == 1) | (img == 4)] # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT - result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) + result.append((img == 1) | (img == 4) | (img == 2)) # label 4 is ET result.append(img == 4) - return np.stack(result, axis=0) + return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0) class AddExtremePointsChannel(Randomizable, Transform): @@ -966,6 +964,8 @@ class TorchVision: """ + backend = [TransformBackends.TORCH] + def __init__(self, name: str, *args, **kwargs) -> None: """ Args: @@ -978,14 +978,16 @@ def __init__(self, name: str, *args, **kwargs) -> None: transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) self.trans = transform(*args, **kwargs) - def __call__(self, img: torch.Tensor): + def __call__(self, img: NdarrayOrTensor): """ Args: img: PyTorch Tensor data for the TorchVision transform. """ - img, *_ = convert_data_type(img, torch.Tensor) # type: ignore - return self.trans(img) + img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out class MapLabelValue: @@ -997,6 +999,8 @@ class MapLabelValue: """ + backend = [TransformBackends.NUMPY] + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: """ Args: @@ -1012,11 +1016,11 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL self.orig_labels = orig_labels self.target_labels = target_labels - self.dtype = dtype + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) - def __call__(self, img: np.ndarray): - img, *_ = convert_data_type(img, np.ndarray) # type: ignore - img_flat = img.flatten() + def __call__(self, img: NdarrayOrTensor): + img_np, *_ = convert_data_type(img, np.ndarray) + img_flat = img_np.flatten() try: out_flat = np.copy(img_flat).astype(self.dtype) except ValueError: @@ -1028,7 +1032,9 @@ def __call__(self, img: np.ndarray): continue np.place(out_flat, img_flat == o, t) - return out_flat.reshape(img.shape) + out = out_flat.reshape(img_np.shape) + out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype) + return out class IntensityStats(Transform): @@ -1050,14 +1056,16 @@ class IntensityStats(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None: self.ops = ensure_tuple(ops) self.key_prefix = key_prefix self.channel_wise = channel_wise def __call__( - self, img: np.ndarray, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, Dict]: + self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None + ) -> Tuple[NdarrayOrTensor, Dict]: """ Compute statistics for the intensity of input image. @@ -1068,15 +1076,15 @@ def __call__( mask must have the same shape as input `img`. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore if meta_data is None: meta_data = {} - img_: np.ndarray = img if mask is not None: - if mask.shape != img.shape or mask.dtype != bool: + if mask.shape != img_np.shape or mask.dtype != bool: raise TypeError("mask must be bool array with the same shape as input `img`.") - img_ = img[mask] + img_np = img_np[mask] supported_ops = { "mean": np.nanmean, @@ -1095,9 +1103,9 @@ def _compute(op: Callable, data: np.ndarray): for o in self.ops: if isinstance(o, str): o = look_up_option(o, supported_ops.keys()) - meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_) # type: ignore + meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_np) # type: ignore elif callable(o): - meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_) + meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_np) custom_index += 1 else: raise ValueError("ops must be key string for predefined operations or callable function.") diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index c4031b43f0..c58f569b8f 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1192,11 +1192,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): and ET (Enhancing tumor). """ + backend = ConvertToMultiChannelBasedOnBratsClasses.backend + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): super().__init__(keys, allow_missing_keys) self.converter = ConvertToMultiChannelBasedOnBratsClasses() - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -1280,6 +1282,8 @@ class TorchVisiond(MapTransform): data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor. """ + backend = TorchVision.backend + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: @@ -1294,7 +1298,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F super().__init__(keys, allow_missing_keys) self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1317,6 +1321,8 @@ class RandTorchVisiond(Randomizable, MapTransform): """ + backend = TorchVision.backend + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: @@ -1331,7 +1337,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F MapTransform.__init__(self, keys, allow_missing_keys) self.trans = TorchVision(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.trans(d[key]) @@ -1343,6 +1349,8 @@ class MapLabelValued(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. """ + backend = MapLabelValue.backend + def __init__( self, keys: KeysCollection, @@ -1364,7 +1372,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.mapper(d[key]) @@ -1406,6 +1414,8 @@ class IntensityStatsd(MapTransform): """ + backend = IntensityStats.backend + def __init__( self, keys: KeysCollection, @@ -1425,7 +1435,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - def __call__(self, data) -> Dict[Hashable, np.ndarray]: + def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key, mask_key, meta_key, meta_key_postfix in self.key_iterator( d, self.mask_keys, self.meta_keys, self.meta_key_postfix @@ -1442,7 +1452,7 @@ class ToDeviced(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. """ - backend = [TransformBackends.TORCH] + backend = ToDevice.backend def __init__( self, keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, **kwargs diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index b7f067076c..240d1138ad 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -244,7 +244,7 @@ def convert_data_type( output_type = output_type or orig_type - dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype, output_type) if output_type is torch.Tensor: data = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index aba10fd717..28f9fbd1bd 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -60,11 +60,6 @@ def test_convert_data_type(self, in_image, im_out): def test_neg_stride(self): _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor) - def test_ill_arg(self): - with self.assertRaises(ValueError): - convert_data_type(None, torch.Tensor) - convert_data_type(None, np.ndarray) - @parameterized.expand(TESTS_LIST) def test_convert_list(self, in_image, im_out, wrap): output_type = type(im_out) if wrap else type(im_out[0]) diff --git a/tests/test_convert_to_multi_channel.py b/tests/test_convert_to_multi_channel.py index 2f7a38e6e4..4892eae809 100644 --- a/tests/test_convert_to_multi_channel.py +++ b/tests/test_convert_to_multi_channel.py @@ -11,34 +11,46 @@ import unittest -import numpy as np +import torch from parameterized import parameterized from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), - np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]), -] - -TEST_CASE_2 = [ - np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), - np.array( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], - [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], - [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + [ + p([[0, 1, 2], [1, 2, 4], [0, 1, 4]]), + p( + [ + [[0, 1, 0], [1, 0, 1], [0, 1, 1]], + [[0, 1, 1], [1, 1, 1], [0, 1, 1]], + [[0, 0, 0], [0, 0, 1], [0, 0, 1]], + ] + ), + ], + [ + p([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]), + p( + [ + [[[0, 1], [1, 0]], [[0, 1], [1, 1]]], + [[[0, 1], [1, 1]], [[1, 1], [1, 1]]], + [[[0, 0], [0, 0]], [[0, 1], [1, 1]]], + ] + ), + ], ] - ), -] + ) class TestConvertToMultiChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand(TESTS) def test_type_shape(self, data, expected_result): result = ConvertToMultiChannelBasedOnBratsClasses()(data) - np.testing.assert_equal(result, expected_result) - self.assertEqual(f"{result.dtype}", "bool") + assert_allclose(result, expected_result) + self.assertTrue(result.dtype in (bool, torch.bool)) if __name__ == "__main__": diff --git a/tests/test_intensity_stats.py b/tests/test_intensity_stats.py index 2647efd7c2..6fe8237f00 100644 --- a/tests/test_intensity_stats.py +++ b/tests/test_intensity_stats.py @@ -15,40 +15,43 @@ from parameterized import parameterized from monai.transforms import IntensityStats +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] - -TEST_CASE_2 = [{"ops": "std", "key_prefix": "orig"}, np.array([[[0.0, 1.0], [2.0, 3.0]]]), None, {"orig_std": 1.118034}] - -TEST_CASE_3 = [ - {"ops": [np.mean, "max", np.min], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - None, - {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, -] - -TEST_CASE_4 = [ - {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, - np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), - {"affine": None}, - {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, -] - -TEST_CASE_5 = [ - {"ops": ["max", "mean"], "key_prefix": "orig"}, - np.array([[[0.0, 1.0], [2.0, 3.0]]]), - {"affine": None}, - {"orig_max": 3.0, "orig_mean": 1.5}, -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( + [ + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + [{"ops": "std", "key_prefix": "orig"}, p([[[0.0, 1.0], [2.0, 3.0]]]), None, {"orig_std": 1.118034}], + [ + {"ops": [np.mean, "max", np.min], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + None, + {"orig_custom_0": 1.5, "orig_max": 3.0, "orig_custom_1": 0.0}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig", "channel_wise": True}, + p([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), + {"affine": None}, + {"orig_max": [3.0, 7.0], "orig_mean": [1.5, 5.5]}, + ], + [ + {"ops": ["max", "mean"], "key_prefix": "orig"}, + p([[[0.0, 1.0], [2.0, 3.0]]]), + {"affine": None}, + {"orig_max": 3.0, "orig_mean": 1.5}, + ], + ] + ) class TestIntensityStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_param, img, meta_dict, expected): _, meta_dict = IntensityStats(**input_param)(img, meta_dict) for k, v in expected.items(): @@ -56,11 +59,12 @@ def test_value(self, input_param, img, meta_dict, expected): np.testing.assert_allclose(v, meta_dict[k], atol=1e-3) def test_mask(self): - img = np.array([[[0.0, 1.0], [2.0, 3.0]]]) - mask = np.array([[[1, 0], [1, 0]]], dtype=bool) - img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) - np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) - np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) + for p in TEST_NDARRAYS: + img = p([[[0.0, 1.0], [2.0, 3.0]]]) + mask = np.array([[[1, 0], [1, 0]]], dtype=bool) + img, meta_dict = IntensityStats(ops=["max", "mean"], key_prefix="orig")(img, mask=mask) + np.testing.assert_allclose(meta_dict["orig_max"], 2.0, atol=1e-3) + np.testing.assert_allclose(meta_dict["orig_mean"], 1.0, atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index 388b6db973..2de549ad23 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -12,62 +12,69 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import MapLabelValue +from monai.utils import PT_BEFORE_1_7 +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, - np.array([[3, 1], [1, 2]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_2 = [ - {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, - np.array([[[3], [5], [5], [8]]]), - np.array([[[0], [1], [1], [2]]]), -] - -TEST_CASE_3 = [{"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, np.array([3, 1, 1, 2]), np.array([2, 0, 0, 1])] - -TEST_CASE_4 = [ - {"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, - np.array([3, 1, 1, 2]), - np.array([2.5, 0.5, 0.5, 1.5]), -] - -TEST_CASE_5 = [ - {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, - np.array([3.5, 1.5, 1.5, 2.5]), - np.array([2, 0, 0, 1]), -] - -TEST_CASE_6 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([[0, 2], [2, 1]]), -] - -TEST_CASE_7 = [ - {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, - np.array([[3.5, 1.5], [1.5, 2.5]]), - np.array([["label0", "label2"], ["label2", "label1"]]), -] - -TEST_CASE_8 = [ - {"orig_labels": ["label3", "label2", "label1"], "target_labels": ["label1", "label2", "label3"], "dtype": "str"}, - np.array([["label3", "label1"], ["label1", "label2"]]), - np.array([["label1", "label3"], ["label3", "label2"]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( + [ + [{"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, p([[3, 1], [1, 2]]), p([[0.0, 2.0], [2.0, 1.0]])], + [ + {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + p([[[3], [5], [5], [8]]]), + p([[[0.0], [1.0], [1.0], [2.0]]]), + ], + [{"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, p([3, 1, 1, 2]), p([2.0, 0.0, 0.0, 1.0])], + [{"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, p([3, 1, 1, 2]), p([2.5, 0.5, 0.5, 1.5])], + ] + ) + # PyTorch 1.5.1 doesn't support rich dtypes + if not PT_BEFORE_1_7: + TESTS.append( + [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + p([3.5, 1.5, 1.5, 2.5]), + p([2, 0, 0, 1]), + ] + ) +TESTS.extend( + [ + [ + {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([[0, 2], [2, 1]]), + ], + [ + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, + np.array([[3.5, 1.5], [1.5, 2.5]]), + np.array([["label0", "label2"], ["label2", "label1"]]), + ], + [ + { + "orig_labels": ["label3", "label2", "label1"], + "target_labels": ["label1", "label2", "label3"], + "dtype": "str", + }, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([["label1", "label3"], ["label3", "label2"]]), + ], + ] +) class TestMapLabelValue(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] - ) + @parameterized.expand(TESTS) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValue(**input_param)(input_data) - np.testing.assert_equal(result, expected_value) + if isinstance(expected_value, torch.Tensor): + torch.testing.assert_allclose(result, expected_value) + else: + np.testing.assert_equal(result, expected_value) self.assertTupleEqual(result.shape, expected_value.shape) diff --git a/tests/test_torchvision.py b/tests/test_torchvision.py index 58e7d9295f..8973ad523f 100644 --- a/tests/test_torchvision.py +++ b/tests/test_torchvision.py @@ -11,51 +11,54 @@ import unittest -import torch from parameterized import parameterized from monai.transforms import TorchVision from monai.utils import set_determinism -from tests.utils import SkipIfBeforePyTorchVersion - -TEST_CASE_1 = [ - {"name": "ColorJitter"}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), -] - -TEST_CASE_2 = [ - {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( - [ - [[0.1090, 0.6193], [0.6193, 0.9164]], - [[0.1090, 0.6193], [0.6193, 0.9164]], - [[0.1090, 0.6193], [0.6193, 0.9164]], - ] - ), -] +from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose -TEST_CASE_3 = [ - {"name": "Pad", "padding": [1, 1, 1, 1]}, - torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), - torch.tensor( +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.extend( [ - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [ + {"name": "ColorJitter"}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + ], + [ + {"name": "ColorJitter", "brightness": 0.5, "contrast": 0.5, "saturation": [0.1, 0.8], "hue": 0.5}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + [[0.1090, 0.6193], [0.6193, 0.9164]], + ] + ), + ], + [ + {"name": "Pad", "padding": [1, 1, 1, 1]}, + p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]), + p( + [ + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ), + ], ] - ), -] + ) @SkipIfBeforePyTorchVersion((1, 7)) class TestTorchVision(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) result = TorchVision(**input_param)(input_data) - torch.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value, rtol=1e-3) if __name__ == "__main__":