Skip to content

Commit

Permalink
2691 Add HistogramIntensity transform (#2738)
Browse files Browse the repository at this point in the history
* [DLMED] add Histogram normalize

Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma authored Aug 11, 2021
1 parent 33184fa commit f9bc713
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 1 deletion.
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ Intensity
:members:
:special-members: __call__

`HistogramNormalize`
""""""""""""""""""""
.. autoclass:: HistogramNormalize
:members:
:special-members: __call__


IO
^^
Expand Down Expand Up @@ -930,6 +936,12 @@ Intensity (Dict)
:members:
:special-members: __call__

`HistogramNormalized`
"""""""""""""""""""""
.. autoclass:: HistogramNormalized
:members:
:special-members: __call__


IO (Dict)
^^^^^^^^^
Expand Down
5 changes: 5 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
GaussianSharpen,
GaussianSmooth,
GibbsNoise,
HistogramNormalize,
KSpaceSpikeNoise,
MaskIntensity,
NormalizeIntensity,
Expand Down Expand Up @@ -120,6 +121,9 @@
GibbsNoised,
GibbsNoiseD,
GibbsNoiseDict,
HistogramNormalized,
HistogramNormalizeD,
HistogramNormalizeDict,
KSpaceSpikeNoised,
KSpaceSpikeNoiseD,
KSpaceSpikeNoiseDict,
Expand Down Expand Up @@ -467,6 +471,7 @@
create_scale,
create_shear,
create_translate,
equalize_hist,
extreme_points_to_image,
generate_label_classes_crop_centers,
generate_pos_neg_label_crop_centers,
Expand Down
45 changes: 44 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import Fourier, RandomizableTransform, Transform
from monai.transforms.utils import is_positive, rescale_array
from monai.transforms.utils import equalize_hist, is_positive, rescale_array
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -64,6 +64,7 @@
"KSpaceSpikeNoise",
"RandKSpaceSpikeNoise",
"RandCoarseDropout",
"HistogramNormalize",
]


Expand Down Expand Up @@ -1626,3 +1627,45 @@ def __call__(self, img: np.ndarray):
img[h] = self.fill_value

return img


class HistogramNormalize(Transform):
"""
Apply the histogram normalization to input image.
Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83.
Args:
num_bins: number of the bins to use in histogram, default to `256`. for more details:
https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.
min: the min value to normalize input image, default to `0`.
max: the max value to normalize input image, default to `255`.
mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
only points at which `mask==True` are used for the equalization.
can also provide the mask along with img at runtime.
dtype: data type of the output, default to `float32`.
"""

def __init__(
self,
num_bins: int = 256,
min: int = 0,
max: int = 255,
mask: Optional[np.ndarray] = None,
dtype: DtypeLike = np.float32,
) -> None:
self.num_bins = num_bins
self.min = min
self.max = max
self.mask = mask
self.dtype = dtype

def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
return equalize_hist(
img=img,
mask=mask if mask is not None else self.mask,
num_bins=self.num_bins,
min=self.min,
max=self.max,
dtype=self.dtype,
)
48 changes: 48 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GaussianSharpen,
GaussianSmooth,
GibbsNoise,
HistogramNormalize,
KSpaceSpikeNoise,
MaskIntensity,
NormalizeIntensity,
Expand Down Expand Up @@ -72,6 +73,7 @@
"RandKSpaceSpikeNoised",
"RandHistogramShiftd",
"RandCoarseDropoutd",
"HistogramNormalized",
"RandGaussianNoiseD",
"RandGaussianNoiseDict",
"ShiftIntensityD",
Expand Down Expand Up @@ -122,6 +124,8 @@
"RandRicianNoiseDict",
"RandCoarseDropoutD",
"RandCoarseDropoutDict",
"HistogramNormalizeD",
"HistogramNormalizeDict",
]


Expand Down Expand Up @@ -1469,6 +1473,49 @@ def __call__(self, data):
return d


class HistogramNormalized(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
num_bins: number of the bins to use in histogram, default to `256`. for more details:
https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.
min: the min value to normalize input image, default to `255`.
max: the max value to normalize input image, default to `255`.
mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
only points at which `mask==True` are used for the equalization.
can also provide the mask by `mask_key` at runtime.
mask_key: if mask is None, will try to get the mask with `mask_key`.
dtype: data type of the output, default to `float32`.
allow_missing_keys: do not raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
num_bins: int = 256,
min: int = 0,
max: int = 255,
mask: Optional[np.ndarray] = None,
mask_key: Optional[str] = None,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype)
self.mask_key = mask_key if mask is None else None

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key])

return d


RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
Expand All @@ -1495,3 +1542,4 @@ def __call__(self, data):
KSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised
RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised
RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd
HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized
44 changes: 44 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ndimage, _ = optional_import("scipy.ndimage")
cp, has_cp = optional_import("cupy")
cp_ndarray, _ = optional_import("cupy", name="ndarray")
exposure, has_skimage = optional_import("skimage.exposure")

__all__ = [
"allow_missing_keys_mode",
Expand Down Expand Up @@ -76,6 +77,7 @@
"tensor_to_numpy",
"weighted_patch_samples",
"zero_margins",
"equalize_hist",
]


Expand Down Expand Up @@ -1115,3 +1117,45 @@ def tensor_to_numpy(data):
return tuple(tensor_to_numpy(i) for i in data)

return data


def equalize_hist(
img: np.ndarray,
mask: Optional[np.ndarray] = None,
num_bins: int = 256,
min: int = 0,
max: int = 255,
dtype: DtypeLike = np.float32,
) -> np.ndarray:
"""
Utility to equalize input image based on the histogram.
If `skimage` installed, will leverage `skimage.exposure.histogram`, otherwise, use
`np.histogram` instead.
Args:
img: input image to equalize.
mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
only points at which `mask==True` are used for the equalization.
num_bins: number of the bins to use in histogram, default to `256`. for more details:
https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.
min: the min value to normalize input image, default to `0`.
max: the max value to normalize input image, default to `255`.
dtype: data type of the output, default to `float32`.
"""
orig_shape = img.shape
hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img
if has_skimage:
hist, bins = exposure.histogram(hist_img.flatten(), num_bins)
else:
hist, bins = np.histogram(hist_img.flatten(), num_bins)
bins = (bins[:-1] + bins[1:]) / 2

cum = hist.cumsum()
# normalize the cumulative result
cum = rescale_array(arr=cum, minv=min, maxv=max)

# apply linear interpolation
img = np.interp(img.flatten(), bins, cum)

return img.reshape(orig_shape).astype(dtype)
47 changes: 47 additions & 0 deletions tests/test_histogram_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import HistogramNormalize

TEST_CASE_1 = [
{"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])},
np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]),
]

TEST_CASE_2 = [
{"num_bins": 4, "max": 4, "dtype": np.uint8},
np.array([0.0, 1.0, 2.0, 3.0, 4.0]),
np.array([0, 0, 1, 3, 4]),
]

TEST_CASE_3 = [
{"num_bins": 256, "max": 255, "dtype": np.uint8},
np.array([[[100.0, 200.0], [150.0, 250.0]]]),
np.array([[[0, 170], [70, 255]]]),
]


class TestHistogramNormalize(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, argments, image, expected_data):
result = HistogramNormalize(**argments)(image)
np.testing.assert_allclose(result, expected_data)
self.assertEqual(result.dtype, argments.get("dtype", np.float32))


if __name__ == "__main__":
unittest.main()
47 changes: 47 additions & 0 deletions tests/test_histogram_normalized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import HistogramNormalized

TEST_CASE_1 = [
{"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"},
{"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])},
np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]),
]

TEST_CASE_2 = [
{"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8},
{"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])},
np.array([0, 0, 1, 3, 4]),
]

TEST_CASE_3 = [
{"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8},
{"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])},
np.array([[[0, 170], [70, 255]]]),
]


class TestHistogramNormalized(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, argments, image, expected_data):
result = HistogramNormalized(**argments)(image)["img"]
np.testing.assert_allclose(result, expected_data)
self.assertEqual(result.dtype, argments.get("dtype", np.float32))


if __name__ == "__main__":
unittest.main()

0 comments on commit f9bc713

Please sign in to comment.