From ef90fdce58830efcac328d7144302967129b9d86 Mon Sep 17 00:00:00 2001 From: DmitriyValetov Date: Wed, 13 Nov 2024 01:18:34 +0300 Subject: [PATCH 1/3] added support for single channel images --- torchcam/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchcam/utils.py b/torchcam/utils.py index 24b7419..d8c34b3 100644 --- a/torchcam/utils.py +++ b/torchcam/utils.py @@ -42,6 +42,10 @@ def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = cmap = cm.get_cmap(colormap) # Resize mask and apply colormap overlay = mask.resize(img.size, resample=Resampling.BICUBIC) - overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) + if len(img.getbands()) == 1: + overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 0]).astype(np.uint8) + else: + overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) + # Overlay the image with the mask return fromarray((alpha * np.asarray(img) + (1 - alpha) * cast(np.ndarray, overlay)).astype(np.uint8)) From 70f0c3329fe0f0900d11ae1a6b083029932781ea Mon Sep 17 00:00:00 2001 From: DmitriyValetov Date: Sat, 16 Nov 2024 23:50:12 +0300 Subject: [PATCH 2/3] refactor + test added --- tests/test_utils.py | 10 +++++++++- torchcam/utils.py | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e2c14ce..72ec6f1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,9 +5,9 @@ def test_overlay_mask(): + # RGB image img = Image.fromarray(np.zeros((4, 4, 3)).astype(np.uint8)) mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8)) - overlayed = utils.overlay_mask(img, mask, alpha=0.7) # Check object type @@ -16,3 +16,11 @@ def test_overlay_mask(): assert np.all(np.asarray(overlayed)[..., 0] == 0) assert np.all(np.asarray(overlayed)[..., 1] == 0) assert np.all(np.asarray(overlayed)[..., 2] == 39) + + # grayscale image + img = Image.fromarray(np.zeros((4, 4)).astype(np.uint8)) + mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8)) + overlayed = utils.overlay_mask(img, mask, alpha=0.7) + + # Verify value + assert np.all(np.asarray(overlayed)[...] == 39) \ No newline at end of file diff --git a/torchcam/utils.py b/torchcam/utils.py index d8c34b3..6b87e28 100644 --- a/torchcam/utils.py +++ b/torchcam/utils.py @@ -39,13 +39,13 @@ def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: raise ValueError("alpha argument is expected to be of type float between 0 and 1") + if not len(img.getbands()) in [1, 3]: + raise ValueError("img argument needs to be a grayscale or RGB image") + cmap = cm.get_cmap(colormap) # Resize mask and apply colormap overlay = mask.resize(img.size, resample=Resampling.BICUBIC) - if len(img.getbands()) == 1: - overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 0]).astype(np.uint8) - else: - overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) + overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 2 if len(img.getbands())==1 else slice(0, 3)]).astype(np.uint8) # Overlay the image with the mask return fromarray((alpha * np.asarray(img) + (1 - alpha) * cast(np.ndarray, overlay)).astype(np.uint8)) From bf30196c54c374dd7f4c727b70a81cea9724574a Mon Sep 17 00:00:00 2001 From: DmitriyValetov Date: Tue, 19 Nov 2024 23:17:42 +0300 Subject: [PATCH 3/3] update to remarks --- tests/test_utils.py | 4 ++-- torchcam/utils.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 72ec6f1..e91a48e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,11 +16,11 @@ def test_overlay_mask(): assert np.all(np.asarray(overlayed)[..., 0] == 0) assert np.all(np.asarray(overlayed)[..., 1] == 0) assert np.all(np.asarray(overlayed)[..., 2] == 39) - + # grayscale image img = Image.fromarray(np.zeros((4, 4)).astype(np.uint8)) mask = Image.fromarray(255 * np.ones((4, 4)).astype(np.uint8)) overlayed = utils.overlay_mask(img, mask, alpha=0.7) # Verify value - assert np.all(np.asarray(overlayed)[...] == 39) \ No newline at end of file + assert np.all(np.asarray(overlayed) == 39) diff --git a/torchcam/utils.py b/torchcam/utils.py index 6b87e28..e67a0d0 100644 --- a/torchcam/utils.py +++ b/torchcam/utils.py @@ -39,13 +39,16 @@ def overlay_mask(img: Image, mask: Image, colormap: str = "jet", alpha: float = if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: raise ValueError("alpha argument is expected to be of type float between 0 and 1") - if not len(img.getbands()) in [1, 3]: + if len(img.getbands()) not in {1, 3}: raise ValueError("img argument needs to be a grayscale or RGB image") cmap = cm.get_cmap(colormap) # Resize mask and apply colormap overlay = mask.resize(img.size, resample=Resampling.BICUBIC) - overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 2 if len(img.getbands())==1 else slice(0, 3)]).astype(np.uint8) + + overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, 2 if len(img.getbands()) == 1 else slice(0, 3)]).astype( + np.uint8 + ) # Overlay the image with the mask return fromarray((alpha * np.asarray(img) + (1 - alpha) * cast(np.ndarray, overlay)).astype(np.uint8))