From f9bc713c5ee0cfc48a58536a7790f029a69272ed Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 12 Aug 2021 03:23:23 +0800 Subject: [PATCH] 2691 Add HistogramIntensity transform (#2738) * [DLMED] add Histogram normalize Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 ++++++ monai/transforms/__init__.py | 5 +++ monai/transforms/intensity/array.py | 45 +++++++++++++++++++++- monai/transforms/intensity/dictionary.py | 48 ++++++++++++++++++++++++ monai/transforms/utils.py | 44 ++++++++++++++++++++++ tests/test_histogram_normalize.py | 47 +++++++++++++++++++++++ tests/test_histogram_normalized.py | 47 +++++++++++++++++++++++ 7 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 tests/test_histogram_normalize.py create mode 100644 tests/test_histogram_normalized.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8a880ff151..f97be395d1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -319,6 +319,12 @@ Intensity :members: :special-members: __call__ +`HistogramNormalize` +"""""""""""""""""""" + .. autoclass:: HistogramNormalize + :members: + :special-members: __call__ + IO ^^ @@ -930,6 +936,12 @@ Intensity (Dict) :members: :special-members: __call__ +`HistogramNormalized` +""""""""""""""""""""" + .. autoclass:: HistogramNormalized + :members: + :special-members: __call__ + IO (Dict) ^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 7f2873cc85..390b85a1b8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -83,6 +83,7 @@ GaussianSharpen, GaussianSmooth, GibbsNoise, + HistogramNormalize, KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, @@ -120,6 +121,9 @@ GibbsNoised, GibbsNoiseD, GibbsNoiseDict, + HistogramNormalized, + HistogramNormalizeD, + HistogramNormalizeDict, KSpaceSpikeNoised, KSpaceSpikeNoiseD, KSpaceSpikeNoiseDict, @@ -467,6 +471,7 @@ create_scale, create_shear, create_translate, + equalize_hist, extreme_points_to_image, generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index c14f2b242f..258d896eb6 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -24,7 +24,7 @@ from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import Fourier, RandomizableTransform, Transform -from monai.transforms.utils import is_positive, rescale_array +from monai.transforms.utils import equalize_hist, is_positive, rescale_array from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -64,6 +64,7 @@ "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", "RandCoarseDropout", + "HistogramNormalize", ] @@ -1626,3 +1627,45 @@ def __call__(self, img: np.ndarray): img[h] = self.fill_value return img + + +class HistogramNormalize(Transform): + """ + Apply the histogram normalization to input image. + Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83. + + Args: + num_bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `0`. + max: the max value to normalize input image, default to `255`. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. + can also provide the mask along with img at runtime. + dtype: data type of the output, default to `float32`. + + """ + + def __init__( + self, + num_bins: int = 256, + min: int = 0, + max: int = 255, + mask: Optional[np.ndarray] = None, + dtype: DtypeLike = np.float32, + ) -> None: + self.num_bins = num_bins + self.min = min + self.max = max + self.mask = mask + self.dtype = dtype + + def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: + return equalize_hist( + img=img, + mask=mask if mask is not None else self.mask, + num_bins=self.num_bins, + min=self.min, + max=self.max, + dtype=self.dtype, + ) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 19323e2020..bc5534b402 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -28,6 +28,7 @@ GaussianSharpen, GaussianSmooth, GibbsNoise, + HistogramNormalize, KSpaceSpikeNoise, MaskIntensity, NormalizeIntensity, @@ -72,6 +73,7 @@ "RandKSpaceSpikeNoised", "RandHistogramShiftd", "RandCoarseDropoutd", + "HistogramNormalized", "RandGaussianNoiseD", "RandGaussianNoiseDict", "ShiftIntensityD", @@ -122,6 +124,8 @@ "RandRicianNoiseDict", "RandCoarseDropoutD", "RandCoarseDropoutDict", + "HistogramNormalizeD", + "HistogramNormalizeDict", ] @@ -1469,6 +1473,49 @@ def __call__(self, data): return d +class HistogramNormalized(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + num_bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `255`. + max: the max value to normalize input image, default to `255`. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. + can also provide the mask by `mask_key` at runtime. + mask_key: if mask is None, will try to get the mask with `mask_key`. + dtype: data type of the output, default to `float32`. + allow_missing_keys: do not raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + num_bins: int = 256, + min: int = 0, + max: int = 255, + mask: Optional[np.ndarray] = None, + mask_key: Optional[str] = None, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) + self.mask_key = mask_key if mask is None else None + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) + + return d + + RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd @@ -1495,3 +1542,4 @@ def __call__(self, data): KSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd +HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 800a779651..e996d7c9ea 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -40,6 +40,7 @@ ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") +exposure, has_skimage = optional_import("skimage.exposure") __all__ = [ "allow_missing_keys_mode", @@ -76,6 +77,7 @@ "tensor_to_numpy", "weighted_patch_samples", "zero_margins", + "equalize_hist", ] @@ -1115,3 +1117,45 @@ def tensor_to_numpy(data): return tuple(tensor_to_numpy(i) for i in data) return data + + +def equalize_hist( + img: np.ndarray, + mask: Optional[np.ndarray] = None, + num_bins: int = 256, + min: int = 0, + max: int = 255, + dtype: DtypeLike = np.float32, +) -> np.ndarray: + """ + Utility to equalize input image based on the histogram. + If `skimage` installed, will leverage `skimage.exposure.histogram`, otherwise, use + `np.histogram` instead. + + Args: + img: input image to equalize. + mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. + only points at which `mask==True` are used for the equalization. + num_bins: number of the bins to use in histogram, default to `256`. for more details: + https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. + min: the min value to normalize input image, default to `0`. + max: the max value to normalize input image, default to `255`. + dtype: data type of the output, default to `float32`. + + """ + orig_shape = img.shape + hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img + if has_skimage: + hist, bins = exposure.histogram(hist_img.flatten(), num_bins) + else: + hist, bins = np.histogram(hist_img.flatten(), num_bins) + bins = (bins[:-1] + bins[1:]) / 2 + + cum = hist.cumsum() + # normalize the cumulative result + cum = rescale_array(arr=cum, minv=min, maxv=max) + + # apply linear interpolation + img = np.interp(img.flatten(), bins, cum) + + return img.reshape(orig_shape).astype(dtype) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py new file mode 100644 index 0000000000..b69fb1d927 --- /dev/null +++ b/tests/test_histogram_normalize.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import HistogramNormalize + +TEST_CASE_1 = [ + {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, + np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), +] + +TEST_CASE_2 = [ + {"num_bins": 4, "max": 4, "dtype": np.uint8}, + np.array([0.0, 1.0, 2.0, 3.0, 4.0]), + np.array([0, 0, 1, 3, 4]), +] + +TEST_CASE_3 = [ + {"num_bins": 256, "max": 255, "dtype": np.uint8}, + np.array([[[100.0, 200.0], [150.0, 250.0]]]), + np.array([[[0, 170], [70, 255]]]), +] + + +class TestHistogramNormalize(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, argments, image, expected_data): + result = HistogramNormalize(**argments)(image) + np.testing.assert_allclose(result, expected_data) + self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py new file mode 100644 index 0000000000..68647e82fb --- /dev/null +++ b/tests/test_histogram_normalized.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import HistogramNormalized + +TEST_CASE_1 = [ + {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, + {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])}, + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), +] + +TEST_CASE_2 = [ + {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, + {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])}, + np.array([0, 0, 1, 3, 4]), +] + +TEST_CASE_3 = [ + {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, + {"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])}, + np.array([[[0, 170], [70, 255]]]), +] + + +class TestHistogramNormalized(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_value(self, argments, image, expected_data): + result = HistogramNormalized(**argments)(image)["img"] + np.testing.assert_allclose(result, expected_data) + self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + + +if __name__ == "__main__": + unittest.main()