From c3649ca2c2ad08e24fc7fc5e397cc21b5f0640d8 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Wed, 29 Nov 2023 13:50:22 +0000 Subject: [PATCH 01/21] SpatialCorrelationCoefficient functionality and module added. --- src/torchmetrics/functional/image/__init__.py | 2 + src/torchmetrics/functional/image/scc.py | 118 ++++++++++++++++++ src/torchmetrics/image/__init__.py | 2 + src/torchmetrics/image/scc.py | 40 ++++++ 4 files changed, 162 insertions(+) create mode 100644 src/torchmetrics/functional/image/scc.py create mode 100644 src/torchmetrics/image/scc.py diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index 329b33b66fe..41c4090a542 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -21,6 +21,7 @@ from torchmetrics.functional.image.rase import relative_average_spectral_error from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window from torchmetrics.functional.image.sam import spectral_angle_mapper +from torchmetrics.functional.image.scc import spatial_correlation_coefficient from torchmetrics.functional.image.ssim import ( multiscale_structural_similarity_index_measure, structural_similarity_index_measure, @@ -45,4 +46,5 @@ "visual_information_fidelity", "learned_perceptual_image_patch_similarity", "perceptual_path_length", + "spatial_correlation_coefficient" ] diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py new file mode 100644 index 00000000000..6c40df40047 --- /dev/null +++ b/src/torchmetrics/functional/image/scc.py @@ -0,0 +1,118 @@ + +from typing import Union, Tuple, Optional +from typing_extensions import Literal + +import torch +from torch import Tensor, tensor +from torch.nn.functional import conv2d + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.distributed import reduce + +def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]: + """Update and returns variables required to compute Spatial Correlation Coefficient. + + Args: + preds: Predicted tensor + target: Ground truth tensor + hp_filter: High-pass filter tensor + window_size: Local window size integer + """ + if preds.dtype != target.dtype: + target = target.to(preds.dtype) + _check_same_shape(preds, target) + if len(preds.shape) not in (3, 4): + raise ValueError( + "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape" + " or batch of grayscale images of BxHxW shape." + f" Got preds: {preds.shape} and target: {target.shape}." + ) + + if len(preds.shape) == 3: + preds = preds.unsqueeze(1) + target = target.unsqueeze(1) + + if not window_size > 0: + raise ValueError( + f"Expected `window_size` to be a positive integer. Got {window_size}." + ) + + if window_size > preds.size(2) or window_size > preds.size(3): + raise ValueError( + f"Expected `window_size` to be less than or equal to the size of the image." + f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}." + ) + + preds = preds.to(torch.float32) + target = target.to(torch.float32) + hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device) + return preds, target, hp_filter + +def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: + if isinstance(pad, int): + pad = (pad, pad, pad, pad) + assert len(pad) == 4 + + left_pad = input[:, :, :, 0:pad[0]].flip(dims=[3]) + right_pad = input[:, :, :, -pad[1]:].flip(dims=[3]) + padded = torch.cat([left_pad, input, right_pad], dim=3) + + top_pad = padded[:, :, 0:pad[2], :].flip(dims=[2]) + bottom_pad = padded[:, :, -pad[3]:, :].flip(dims=[2]) + padded = torch.cat([top_pad, padded, bottom_pad], dim=2) + return padded + +def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: + left_padding = int(torch.floor(tensor((kernel.size(3)-1)/2)).item()) + right_padding = int(torch.ceil(tensor((kernel.size(3)-1)/2)).item()) + top_padding = int(torch.floor(tensor((kernel.size(2)-1)/2)).item()) + bottom_padding = int(torch.ceil(tensor((kernel.size(2)-1)/2)).item()) + + padded = _symmetric_reflect_pad_2d(input, pad=(left_padding, right_padding, top_padding, bottom_padding)) + kernel = kernel.flip([2, 3]) + out = conv2d(padded, kernel, stride=1, padding=0) + return out + +def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: + output = _signal_convolve_2d(input, kernel) + output += _signal_convolve_2d(input, kernel) + return output + +def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor): + preds_mean = conv2d(preds, window, stride=1, padding='same') + target_mean = conv2d(target, window, stride=1, padding='same') + + preds_var = conv2d(preds**2, window, stride=1, padding='same') - preds_mean**2 + target_var = conv2d(target**2, window, stride=1, padding='same') - target_mean**2 + target_preds_cov = conv2d(target*preds, window, stride=1, padding='same') - target_mean*preds_mean + + return preds_var, target_var, target_preds_cov + +def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int): + dtype = preds.dtype + device = preds.device + + window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device)/( window_size**2 ) + + preds_hp = _hp_2d_laplacian(preds, hp_filter) + target_hp = _hp_2d_laplacian(target, hp_filter) + + preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window) + + preds_var[preds_var<0] = 0 + target_var[target_var<0] = 0 + + den = torch.sqrt(target_var) * torch.sqrt(preds_var) + idx = (den==0) + den[den == 0] = 1 + scc = target_preds_cov / den + scc[idx] = 0 + return scc + +def spatial_correlation_coefficient(preds: Tensor, target: Tensor, + hp_filter: Tensor = tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]), + window_size: int = 8): + preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) + + per_channel = [_scc_per_channel_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size) for i in range(preds.size(1))] + return reduce(torch.cat(per_channel, dim=1), reduction='elementwise_mean') diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 1defa78bbf5..623fb32af23 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.image.vif import VisualInformationFidelity +from torchmetrics.image.scc import SpatialCorrelationCoefficient from torchmetrics.utilities.imports import ( _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE, @@ -42,6 +43,7 @@ "UniversalImageQualityIndex", "VisualInformationFidelity", "TotalVariation", + "SpatialCorrelationCoefficient", ] if _TORCH_FIDELITY_AVAILABLE: diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py new file mode 100644 index 00000000000..7fe6e5d0754 --- /dev/null +++ b/src/torchmetrics/image/scc.py @@ -0,0 +1,40 @@ +from typing import Any, Optional +from typing_extensions import Literal + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute, _scc_update +from torchmetrics.metric import Metric + +class SpatialCorrelationCoefficient(Metric): + is_differentiable = True + higher_is_better = True + full_state_update = False + + scc_score: Tensor + total: Tensor + + def __init__(self, high_pass_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), + window_size: int = 11, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.hp_filter = high_pass_filter + self.ws = window_size + + self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws) + scc_per_channel = [ + _scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws) + for i in range(preds.size(1)) + ] + self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1,2,3])) + self.total += preds.size(0) + + def compute(self) -> Tensor: + """Compute the VIF score based on inputs passed in to ``update`` previously.""" + return self.scc_score / self.total \ No newline at end of file From be3da8724f5b67d4fadc723013871a8654a7986e Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Wed, 29 Nov 2023 13:57:01 +0000 Subject: [PATCH 02/21] tests for spatial correlation coefficient (scc) added. --- tests/unittests/image/test_scc.py | 60 +++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/unittests/image/test_scc.py diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py new file mode 100644 index 00000000000..322a58b5bfc --- /dev/null +++ b/tests/unittests/image/test_scc.py @@ -0,0 +1,60 @@ +from collections import namedtuple + +import numpy as np +import pytest +import torch + +from torchmetrics.functional.image import spatial_correlation_coefficient +from torchmetrics.image import SpatialCorrelationCoefficient +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +from sewar.full_ref import scc as sewar_scc + +seed_all(42) + +Input = namedtuple('Input', ["preds", "target"]) +_inputs = [ + Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), + target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128) + ) + for channels in [1, 3] +] + +def _reference_scc(preds, target): + """Reference implementation of scc from sewar""" + preds = torch.movedim(preds, 1, -1) + target = torch.movedim(target, 1, -1) + preds = preds.cpu().numpy() + target = target.cpu().numpy() + hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + window_size = 8 + scc = [sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0])] + return np.mean(scc) + +@pytest.mark.parametrize( + "preds, target", + [(i.preds, i.target) for i in _inputs] +) +class TestSpatialCorrelationCoefficient(MetricTester): + atol = 1e-3 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_scc(self, preds, target, ddp): + self.run_class_metric_test( + ddp, + preds, + target, + metric_class=SpatialCorrelationCoefficient, + reference_metric=_reference_scc + ) + + def test_scc_functional(self, preds, target): + self.run_functional_metric_test( + preds, + target, + metric_functional=spatial_correlation_coefficient, + reference_metric=_reference_scc, + ) \ No newline at end of file From d678f2aee84a31459595f80a975bf14d1aa7ad03 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Wed, 29 Nov 2023 14:32:51 +0000 Subject: [PATCH 03/21] documentation for spatial correlation coefficient (scc) updated. --- .../image/spatial_correlation_coefficient.rst | 21 ++++++++ docs/source/links.rst | 1 + src/torchmetrics/functional/image/scc.py | 49 ++++++++++++++++++- src/torchmetrics/image/scc.py | 31 +++++++++++- 4 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 docs/source/image/spatial_correlation_coefficient.rst diff --git a/docs/source/image/spatial_correlation_coefficient.rst b/docs/source/image/spatial_correlation_coefficient.rst new file mode 100644 index 00000000000..d08e2c4a99e --- /dev/null +++ b/docs/source/image/spatial_correlation_coefficient.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Spatial Correlation Coefficient (SCC) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +################################# +Spatial Correlation Coefficient (SCC) +################################# + +Module Interface +________________ + +.. autoclass:: torchmetrics.image.SpatialCorrelationCoefficient + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.image.spatial_correlation_coefficient diff --git a/docs/source/links.rst b/docs/source/links.rst index 4ee095dd659..9ff27f4c95a 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -167,3 +167,4 @@ .. _FLORES-101: https://arxiv.org/abs/2106.03193 .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html +.. _SCC: https://www.tandfonline.com/doi/abs/10.1080/014311698215973 diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 6c40df40047..bf82d2665fa 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -1,6 +1,5 @@ -from typing import Union, Tuple, Optional -from typing_extensions import Literal +from typing import Union, Tuple import torch from torch import Tensor, tensor @@ -17,6 +16,17 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i target: Ground truth tensor hp_filter: High-pass filter tensor window_size: Local window size integer + + Return: + Tuple of (preds, target, hp_filter) tensors + + Raises: + ValueError: + If ``preds`` and ``target`` have different number of channels + If ``preds`` and ``target`` have different shapes + If ``preds`` and ``target`` have invalid shapes + If ``window_size`` is not a positive integer + If ``window_size`` is greater than the size of the image """ if preds.dtype != target.dtype: target = target.to(preds.dtype) @@ -89,9 +99,22 @@ def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor): return preds_var, target_var, target_preds_cov def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int): + """Computes per channel Spatial Correlation Coefficient. + + Args: + preds: estimated image of Bx1xHxW shape. + target: ground truth image of Bx1xHxW shape. + hp_filter: 2D high-pass filter. + window_size: size of window for local mean calculation. + Return: + Tensor with Spatial Correlation Coefficient score""" + dtype = preds.dtype device = preds.device + # This code is inspired by + # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. + window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device)/( window_size**2 ) preds_hp = _hp_2d_laplacian(preds, hp_filter) @@ -112,6 +135,28 @@ def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, w def spatial_correlation_coefficient(preds: Tensor, target: Tensor, hp_filter: Tensor = tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]), window_size: int = 8): + """Compute Spatial Correlation Coefficient (SCC_). + + Args: + preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``. + target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``. + hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]) + window_size: Local window size integer. default: 8 + + Return: + Tensor with scc score + + Example: + >>> import torch + >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc + >>> _ = torch.manual_seed(42) + >>> x = torch.randn(5, 3, 16, 16) + >>> scc(x, x) + tensor(1.) + >>> x = torch.randn(5, 16, 16) + >>> scc(x, x) + tensor(1.) + """ preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) per_channel = [_scc_per_channel_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size) for i in range(preds.size(1))] diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 7fe6e5d0754..1f7366c4f14 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -1,5 +1,4 @@ -from typing import Any, Optional -from typing_extensions import Literal +from typing import Any import torch from torch import Tensor, tensor @@ -8,6 +7,34 @@ from torchmetrics.metric import Metric class SpatialCorrelationCoefficient(Metric): + """Compute Spatial Correlation Coefficient (SCC_). + + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``. + - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``. + + As output of `forward` and `compute` the metric returns the following output + + - ``scc`` (:class:`~torch.Tensor`): Tensor with scc score + + Args: + hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]). + window_size: Local window size integer. default: 8. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC + >>> preds = torch.randn([32, 3, 64, 64]) + >>> target = torch.randn([32, 3, 64, 64]) + >>> scc = SCC() + >>> scc(preds, target) + tensor(0.0022) + + """ + is_differentiable = True higher_is_better = True full_state_update = False From 952233bf0b1dd0a3664a4223b8861a1a1d73a9ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:52:53 +0000 Subject: [PATCH 04/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/__init__.py | 2 +- src/torchmetrics/functional/image/scc.py | 86 +++++++++++-------- src/torchmetrics/image/__init__.py | 2 +- src/torchmetrics/image/scc.py | 26 +++--- tests/unittests/image/test_scc.py | 30 +++---- 5 files changed, 81 insertions(+), 65 deletions(-) diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index 41c4090a542..0e713683b28 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -46,5 +46,5 @@ "visual_information_fidelity", "learned_perceptual_image_patch_similarity", "perceptual_path_length", - "spatial_correlation_coefficient" + "spatial_correlation_coefficient", ] diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index bf82d2665fa..b5a87ee1f8d 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -1,5 +1,4 @@ - -from typing import Union, Tuple +from typing import Tuple, Union import torch from torch import Tensor, tensor @@ -8,6 +7,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce + def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]: """Update and returns variables required to compute Spatial Correlation Coefficient. @@ -27,6 +27,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i If ``preds`` and ``target`` have invalid shapes If ``window_size`` is not a positive integer If ``window_size`` is greater than the size of the image + """ if preds.dtype != target.dtype: target = target.to(preds.dtype) @@ -37,67 +38,68 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i " or batch of grayscale images of BxHxW shape." f" Got preds: {preds.shape} and target: {target.shape}." ) - + if len(preds.shape) == 3: preds = preds.unsqueeze(1) target = target.unsqueeze(1) if not window_size > 0: - raise ValueError( - f"Expected `window_size` to be a positive integer. Got {window_size}." - ) + raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.") if window_size > preds.size(2) or window_size > preds.size(3): raise ValueError( f"Expected `window_size` to be less than or equal to the size of the image." f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}." ) - + preds = preds.to(torch.float32) target = target.to(torch.float32) hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device) return preds, target, hp_filter + def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: if isinstance(pad, int): pad = (pad, pad, pad, pad) assert len(pad) == 4 - left_pad = input[:, :, :, 0:pad[0]].flip(dims=[3]) - right_pad = input[:, :, :, -pad[1]:].flip(dims=[3]) + left_pad = input[:, :, :, 0 : pad[0]].flip(dims=[3]) + right_pad = input[:, :, :, -pad[1] :].flip(dims=[3]) padded = torch.cat([left_pad, input, right_pad], dim=3) - top_pad = padded[:, :, 0:pad[2], :].flip(dims=[2]) - bottom_pad = padded[:, :, -pad[3]:, :].flip(dims=[2]) - padded = torch.cat([top_pad, padded, bottom_pad], dim=2) - return padded + top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2]) + bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2]) + return torch.cat([top_pad, padded, bottom_pad], dim=2) + def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: - left_padding = int(torch.floor(tensor((kernel.size(3)-1)/2)).item()) - right_padding = int(torch.ceil(tensor((kernel.size(3)-1)/2)).item()) - top_padding = int(torch.floor(tensor((kernel.size(2)-1)/2)).item()) - bottom_padding = int(torch.ceil(tensor((kernel.size(2)-1)/2)).item()) + left_padding = int(torch.floor(tensor((kernel.size(3) - 1) / 2)).item()) + right_padding = int(torch.ceil(tensor((kernel.size(3) - 1) / 2)).item()) + top_padding = int(torch.floor(tensor((kernel.size(2) - 1) / 2)).item()) + bottom_padding = int(torch.ceil(tensor((kernel.size(2) - 1) / 2)).item()) padded = _symmetric_reflect_pad_2d(input, pad=(left_padding, right_padding, top_padding, bottom_padding)) kernel = kernel.flip([2, 3]) - out = conv2d(padded, kernel, stride=1, padding=0) - return out + return conv2d(padded, kernel, stride=1, padding=0) + def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: output = _signal_convolve_2d(input, kernel) output += _signal_convolve_2d(input, kernel) return output + def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor): - preds_mean = conv2d(preds, window, stride=1, padding='same') - target_mean = conv2d(target, window, stride=1, padding='same') + preds_mean = conv2d(preds, window, stride=1, padding="same") + target_mean = conv2d(target, window, stride=1, padding="same") - preds_var = conv2d(preds**2, window, stride=1, padding='same') - preds_mean**2 - target_var = conv2d(target**2, window, stride=1, padding='same') - target_mean**2 - target_preds_cov = conv2d(target*preds, window, stride=1, padding='same') - target_mean*preds_mean + preds_var = conv2d(preds**2, window, stride=1, padding="same") - preds_mean**2 + target_var = conv2d(target**2, window, stride=1, padding="same") - target_mean**2 + target_preds_cov = conv2d(target * preds, window, stride=1, padding="same") - target_mean * preds_mean return preds_var, target_var, target_preds_cov + def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int): """Computes per channel Spatial Correlation Coefficient. @@ -106,35 +108,41 @@ def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, w target: ground truth image of Bx1xHxW shape. hp_filter: 2D high-pass filter. window_size: size of window for local mean calculation. + Return: - Tensor with Spatial Correlation Coefficient score""" - + Tensor with Spatial Correlation Coefficient score + + """ dtype = preds.dtype device = preds.device # This code is inspired by # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. - window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device)/( window_size**2 ) - + window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2) + preds_hp = _hp_2d_laplacian(preds, hp_filter) target_hp = _hp_2d_laplacian(target, hp_filter) preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window) - preds_var[preds_var<0] = 0 - target_var[target_var<0] = 0 + preds_var[preds_var < 0] = 0 + target_var[target_var < 0] = 0 den = torch.sqrt(target_var) * torch.sqrt(preds_var) - idx = (den==0) + idx = den == 0 den[den == 0] = 1 scc = target_preds_cov / den scc[idx] = 0 return scc -def spatial_correlation_coefficient(preds: Tensor, target: Tensor, - hp_filter: Tensor = tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]), - window_size: int = 8): + +def spatial_correlation_coefficient( + preds: Tensor, + target: Tensor, + hp_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), + window_size: int = 8, +): """Compute Spatial Correlation Coefficient (SCC_). Args: @@ -156,8 +164,14 @@ def spatial_correlation_coefficient(preds: Tensor, target: Tensor, >>> x = torch.randn(5, 16, 16) >>> scc(x, x) tensor(1.) + """ preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) - per_channel = [_scc_per_channel_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size) for i in range(preds.size(1))] - return reduce(torch.cat(per_channel, dim=1), reduction='elementwise_mean') + per_channel = [ + _scc_per_channel_compute( + preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size + ) + for i in range(preds.size(1)) + ] + return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean") diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 623fb32af23..503ed7ad82a 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -19,11 +19,11 @@ from torchmetrics.image.rase import RelativeAverageSpectralError from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow from torchmetrics.image.sam import SpectralAngleMapper +from torchmetrics.image.scc import SpatialCorrelationCoefficient from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure from torchmetrics.image.tv import TotalVariation from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.image.vif import VisualInformationFidelity -from torchmetrics.image.scc import SpatialCorrelationCoefficient from torchmetrics.utilities.imports import ( _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE, diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 1f7366c4f14..64a660009b9 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -3,9 +3,11 @@ import torch from torch import Tensor, tensor -from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute, _scc_update +from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute +from torchmetrics.functional.image.scc import _scc_update from torchmetrics.metric import Metric + class SpatialCorrelationCoefficient(Metric): """Compute Spatial Correlation Coefficient (SCC_). @@ -42,26 +44,30 @@ class SpatialCorrelationCoefficient(Metric): scc_score: Tensor total: Tensor - def __init__(self, high_pass_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), - window_size: int = 11, **kwargs: Any) -> None: + def __init__( + self, + high_pass_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), + window_size: int = 11, + **kwargs: Any + ) -> None: super().__init__(**kwargs) - + self.hp_filter = high_pass_filter self.ws = window_size - + self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum") - + def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws) scc_per_channel = [ _scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws) for i in range(preds.size(1)) - ] - self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1,2,3])) + ] + self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3])) self.total += preds.size(0) - + def compute(self) -> Tensor: """Compute the VIF score based on inputs passed in to ``update`` previously.""" - return self.scc_score / self.total \ No newline at end of file + return self.scc_score / self.total diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 322a58b5bfc..61e065b844c 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -3,52 +3,48 @@ import numpy as np import pytest import torch - +from sewar.full_ref import scc as sewar_scc from torchmetrics.functional.image import spatial_correlation_coefficient from torchmetrics.image import SpatialCorrelationCoefficient + from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester -from sewar.full_ref import scc as sewar_scc - seed_all(42) -Input = namedtuple('Input', ["preds", "target"]) +Input = namedtuple("Input", ["preds", "target"]) _inputs = [ Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), - target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128) + target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), ) for channels in [1, 3] ] + def _reference_scc(preds, target): - """Reference implementation of scc from sewar""" + """Reference implementation of scc from sewar.""" preds = torch.movedim(preds, 1, -1) target = torch.movedim(target, 1, -1) preds = preds.cpu().numpy() target = target.cpu().numpy() hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) window_size = 8 - scc = [sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0])] + scc = [ + sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0]) + ] return np.mean(scc) -@pytest.mark.parametrize( - "preds, target", - [(i.preds, i.target) for i in _inputs] -) + +@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) class TestSpatialCorrelationCoefficient(MetricTester): atol = 1e-3 @pytest.mark.parametrize("ddp", [True, False]) def test_scc(self, preds, target, ddp): self.run_class_metric_test( - ddp, - preds, - target, - metric_class=SpatialCorrelationCoefficient, - reference_metric=_reference_scc + ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc ) def test_scc_functional(self, preds, target): @@ -57,4 +53,4 @@ def test_scc_functional(self, preds, target): target, metric_functional=spatial_correlation_coefficient, reference_metric=_reference_scc, - ) \ No newline at end of file + ) From 48164a4d0eac395cab0fa6f02809e81377f7b003 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 30 Nov 2023 13:17:34 +0100 Subject: [PATCH 05/21] Apply suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- .../image/spatial_correlation_coefficient.rst | 4 +-- src/torchmetrics/functional/image/scc.py | 32 +++++++++++++------ src/torchmetrics/image/scc.py | 13 ++++++++ tests/unittests/image/test_scc.py | 13 ++++++++ 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/docs/source/image/spatial_correlation_coefficient.rst b/docs/source/image/spatial_correlation_coefficient.rst index d08e2c4a99e..02ed96fd107 100644 --- a/docs/source/image/spatial_correlation_coefficient.rst +++ b/docs/source/image/spatial_correlation_coefficient.rst @@ -5,9 +5,9 @@ .. include:: ../links.rst -################################# +##################################### Spatial Correlation Coefficient (SCC) -################################# +##################################### Module Interface ________________ diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index b5a87ee1f8d..cdc48ed9cf1 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Tuple, Union import torch @@ -32,7 +45,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i if preds.dtype != target.dtype: target = target.to(preds.dtype) _check_same_shape(preds, target) - if len(preds.shape) not in (3, 4): + if preds.ndim not in (3, 4): raise ValueError( "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape" " or batch of grayscale images of BxHxW shape." @@ -61,7 +74,8 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: if isinstance(pad, int): pad = (pad, pad, pad, pad) - assert len(pad) == 4 + if len(pad) != 4: + raise ValueError(f"Expected padding to have length 4, but got {len(pad)}") left_pad = input[:, :, :, 0 : pad[0]].flip(dims=[3]) right_pad = input[:, :, :, -pad[1] :].flip(dims=[3]) @@ -73,10 +87,10 @@ def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) - def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: - left_padding = int(torch.floor(tensor((kernel.size(3) - 1) / 2)).item()) - right_padding = int(torch.ceil(tensor((kernel.size(3) - 1) / 2)).item()) - top_padding = int(torch.floor(tensor((kernel.size(2) - 1) / 2)).item()) - bottom_padding = int(torch.ceil(tensor((kernel.size(2) - 1) / 2)).item()) + left_padding = int(math.floor((kernel.size(3) - 1) / 2)) + right_padding = int(math.ceil((kernel.size(3) - 1) / 2)) + top_padding = int(math.floor((kernel.size(2) - 1) / 2)) + bottom_padding = int(math.ceil((kernel.size(2) - 1) / 2)) padded = _symmetric_reflect_pad_2d(input, pad=(left_padding, right_padding, top_padding, bottom_padding)) kernel = kernel.flip([2, 3]) @@ -89,7 +103,7 @@ def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: return output -def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor): +def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: preds_mean = conv2d(preds, window, stride=1, padding="same") target_mean = conv2d(target, window, stride=1, padding="same") @@ -100,7 +114,7 @@ def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor): return preds_var, target_var, target_preds_cov -def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int): +def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor: """Computes per channel Spatial Correlation Coefficient. Args: @@ -142,7 +156,7 @@ def spatial_correlation_coefficient( target: Tensor, hp_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), window_size: int = 8, -): +) -> Tensor: """Compute Spatial Correlation Coefficient (SCC_). Args: diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 64a660009b9..701474c307c 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any import torch diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 61e065b844c..d54fb2e4c87 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import namedtuple import numpy as np From 2637b2396b5fa57fa415c3538162095bd7336fb9 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Thu, 30 Nov 2023 13:21:59 +0000 Subject: [PATCH 06/21] scc functional docstrings added. _hp_2d_laplacian function updated --- src/torchmetrics/functional/image/scc.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index cdc48ed9cf1..1b2e7205034 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Tuple, Union import torch @@ -72,6 +73,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: + """Applies symmetric padding to the 2D image tensor input using `reflect` mode (d c b a | a b c d | d c b a).""" if isinstance(pad, int): pad = (pad, pad, pad, pad) if len(pad) != 4: @@ -87,6 +89,7 @@ def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) - def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: + """Applies 2D signal convolution to the input tensor with the given kernel.""" left_padding = int(math.floor((kernel.size(3) - 1) / 2)) right_padding = int(math.ceil((kernel.size(3) - 1) / 2)) top_padding = int(math.floor((kernel.size(2) - 1) / 2)) @@ -98,12 +101,17 @@ def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: - output = _signal_convolve_2d(input, kernel) - output += _signal_convolve_2d(input, kernel) + """Applies 2-D Laplace filter to the input tensor with the given high pass filter.""" + output = _signal_convolve_2d(input, kernel) * 2. return output def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Computes local variance and covariance of the input tensors.""" + + # This code is inspired by + # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. + preds_mean = conv2d(preds, window, stride=1, padding="same") target_mean = conv2d(target, window, stride=1, padding="same") From 4a0058af35efbf89708665caabd83646a8581dc2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Nov 2023 13:23:31 +0000 Subject: [PATCH 07/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/scc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 1b2e7205034..86600ed52f4 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -102,13 +102,11 @@ def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: """Applies 2-D Laplace filter to the input tensor with the given high pass filter.""" - output = _signal_convolve_2d(input, kernel) * 2. - return output + return _signal_convolve_2d(input, kernel) * 2.0 def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """Computes local variance and covariance of the input tensors.""" - # This code is inspired by # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. From 06682361840c0c4564314d60e3a4e83aa0799591 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Thu, 30 Nov 2023 15:28:01 +0000 Subject: [PATCH 08/21] required failed checks resolved. --- src/torchmetrics/functional/image/scc.py | 20 +++++++++++--------- src/torchmetrics/image/scc.py | 7 +++++-- tests/unittests/image/test_scc.py | 9 +++++++-- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 1b2e7205034..58238ca4dd2 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -72,37 +72,37 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i return preds, target, hp_filter -def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: +def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: """Applies symmetric padding to the 2D image tensor input using `reflect` mode (d c b a | a b c d | d c b a).""" if isinstance(pad, int): pad = (pad, pad, pad, pad) if len(pad) != 4: raise ValueError(f"Expected padding to have length 4, but got {len(pad)}") - left_pad = input[:, :, :, 0 : pad[0]].flip(dims=[3]) - right_pad = input[:, :, :, -pad[1] :].flip(dims=[3]) - padded = torch.cat([left_pad, input, right_pad], dim=3) + left_pad = input_img[:, :, :, 0 : pad[0]].flip(dims=[3]) + right_pad = input_img[:, :, :, -pad[1] :].flip(dims=[3]) + padded = torch.cat([left_pad, input_img, right_pad], dim=3) top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2]) bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2]) return torch.cat([top_pad, padded, bottom_pad], dim=2) -def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor: +def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor: """Applies 2D signal convolution to the input tensor with the given kernel.""" left_padding = int(math.floor((kernel.size(3) - 1) / 2)) right_padding = int(math.ceil((kernel.size(3) - 1) / 2)) top_padding = int(math.floor((kernel.size(2) - 1) / 2)) bottom_padding = int(math.ceil((kernel.size(2) - 1) / 2)) - padded = _symmetric_reflect_pad_2d(input, pad=(left_padding, right_padding, top_padding, bottom_padding)) + padded = _symmetric_reflect_pad_2d(input_img, pad=(left_padding, right_padding, top_padding, bottom_padding)) kernel = kernel.flip([2, 3]) return conv2d(padded, kernel, stride=1, padding=0) -def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor: +def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor: """Applies 2-D Laplace filter to the input tensor with the given high pass filter.""" - output = _signal_convolve_2d(input, kernel) * 2. + output = _signal_convolve_2d(input_img, kernel) * 2. return output @@ -162,7 +162,7 @@ def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, w def spatial_correlation_coefficient( preds: Tensor, target: Tensor, - hp_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), + hp_filter: Tensor = None, window_size: int = 8, ) -> Tensor: """Compute Spatial Correlation Coefficient (SCC_). @@ -188,6 +188,8 @@ def spatial_correlation_coefficient( tensor(1.) """ + if hp_filter is None: + hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) per_channel = [ diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 701474c307c..f406aa312fe 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -59,12 +59,15 @@ class SpatialCorrelationCoefficient(Metric): def __init__( self, - high_pass_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), - window_size: int = 11, + high_pass_filter: Tensor = None, + window_size: int = 8, **kwargs: Any ) -> None: super().__init__(**kwargs) + if high_pass_filter is None: + high_pass_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + self.hp_filter = high_pass_filter self.ws = window_size diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index d54fb2e4c87..de96f6736f0 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import NamedTuple import numpy as np import pytest @@ -26,7 +26,9 @@ seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) +class Input(NamedTuple): + preds: torch.Tensor + target: torch.Tensor _inputs = [ Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), @@ -52,15 +54,18 @@ def _reference_scc(preds, target): @pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) class TestSpatialCorrelationCoefficient(MetricTester): + """Tests for SpatialCorrelationCoefficient metric.""" atol = 1e-3 @pytest.mark.parametrize("ddp", [True, False]) def test_scc(self, preds, target, ddp): + """Test SpatialCorrelationCoefficient class usage.""" self.run_class_metric_test( ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc ) def test_scc_functional(self, preds, target): + """Test SpatialCorrelationCoefficient functional usage.""" self.run_functional_metric_test( preds, target, From 24bfc7766639027b7d33a1713e234b37ec7334f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Nov 2023 15:31:43 +0000 Subject: [PATCH 09/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/image/scc.py | 7 +------ tests/unittests/image/test_scc.py | 4 ++++ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index f406aa312fe..185a19ebf96 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -57,12 +57,7 @@ class SpatialCorrelationCoefficient(Metric): scc_score: Tensor total: Tensor - def __init__( - self, - high_pass_filter: Tensor = None, - window_size: int = 8, - **kwargs: Any - ) -> None: + def __init__(self, high_pass_filter: Tensor = None, window_size: int = 8, **kwargs: Any) -> None: super().__init__(**kwargs) if high_pass_filter is None: diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index de96f6736f0..2bb35b73dfe 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -26,9 +26,12 @@ seed_all(42) + class Input(NamedTuple): preds: torch.Tensor target: torch.Tensor + + _inputs = [ Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), @@ -55,6 +58,7 @@ def _reference_scc(preds, target): @pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) class TestSpatialCorrelationCoefficient(MetricTester): """Tests for SpatialCorrelationCoefficient metric.""" + atol = 1e-3 @pytest.mark.parametrize("ddp", [True, False]) From 33b6d9073660b179159a8a5602c796f6137fbcfa Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Thu, 30 Nov 2023 15:46:35 +0000 Subject: [PATCH 10/21] fixing the variable name mistake. --- src/torchmetrics/functional/image/scc.py | 2 +- tests/unittests/image/test_scc.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 807c73f1655..93b0c8fd1db 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -102,7 +102,7 @@ def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor: def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor: """Applies 2-D Laplace filter to the input tensor with the given high pass filter.""" - return _signal_convolve_2d(input, kernel) * 2.0 + return _signal_convolve_2d(input_img, kernel) * 2.0 def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> Tuple[Tensor, Tensor, Tensor]: diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index de96f6736f0..2522a087e0a 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import NamedTuple import numpy as np import pytest @@ -20,17 +19,14 @@ from torchmetrics.functional.image import spatial_correlation_coefficient from torchmetrics.image import SpatialCorrelationCoefficient -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -class Input(NamedTuple): - preds: torch.Tensor - target: torch.Tensor _inputs = [ - Input( + _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), ) From f5a874f662050cac8893da77d80e34c5935203a7 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Mon, 4 Dec 2023 19:37:52 +0000 Subject: [PATCH 11/21] fixed even window size bug. changed atol to 1e-8. added None reduction in scc functional interface --- src/torchmetrics/functional/image/scc.py | 34 +++++++++++++++++------- tests/unittests/image/test_scc.py | 8 +++--- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 93b0c8fd1db..59a28aa6105 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple, Union +from typing import Tuple, Union, Optional +from typing_extensions import Literal import torch from torch import Tensor, tensor -from torch.nn.functional import conv2d +from torch.nn.functional import conv2d, pad from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce @@ -110,12 +111,18 @@ def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> # This code is inspired by # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187. - preds_mean = conv2d(preds, window, stride=1, padding="same") - target_mean = conv2d(target, window, stride=1, padding="same") + left_padding = int(math.ceil((window.size(3) - 1) / 2)) + right_padding = int(math.floor((window.size(3) - 1) / 2)) - preds_var = conv2d(preds**2, window, stride=1, padding="same") - preds_mean**2 - target_var = conv2d(target**2, window, stride=1, padding="same") - target_mean**2 - target_preds_cov = conv2d(target * preds, window, stride=1, padding="same") - target_mean * preds_mean + preds = pad(preds, (left_padding, right_padding, left_padding, right_padding)) + target = pad(target, (left_padding, right_padding, left_padding, right_padding)) + + preds_mean = conv2d(preds, window, stride=1, padding=0) + target_mean = conv2d(target, window, stride=1, padding=0) + + preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2 + target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2 + target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean return preds_var, target_var, target_preds_cov @@ -162,6 +169,7 @@ def spatial_correlation_coefficient( target: Tensor, hp_filter: Tensor = None, window_size: int = 8, + reduction: Optional[Literal["mean", "none", None]] = "mean", ) -> Tensor: """Compute Spatial Correlation Coefficient (SCC_). @@ -169,7 +177,8 @@ def spatial_correlation_coefficient( preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``. target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``. hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]) - window_size: Local window size integer. default: 8 + window_size: Local window size integer. default: 8, + reduction: Reduction method for output tensor. If ``None`` or ``"none"``, returns a tensor with the per sample results. default: ``"mean"``. Return: Tensor with scc score @@ -188,6 +197,10 @@ def spatial_correlation_coefficient( """ if hp_filter is None: hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) + if reduce is None: + reduction = "none" + if reduction not in ("mean", "none"): + raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}") preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size) per_channel = [ @@ -196,4 +209,7 @@ def spatial_correlation_coefficient( ) for i in range(preds.size(1)) ] - return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean") + if reduction == "none": + return torch.mean(torch.cat(per_channel, dim=1), dim=[1, 2, 3]) + if reduction == "mean": + return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean") diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 198c110c46f..6ff6fbb5657 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -27,8 +27,8 @@ _inputs = [ _Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), - target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128), + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 32, 32), + target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 32, 32), ) for channels in [1, 3] ] @@ -52,7 +52,7 @@ def _reference_scc(preds, target): class TestSpatialCorrelationCoefficient(MetricTester): """Tests for SpatialCorrelationCoefficient metric.""" - atol = 1e-3 + atol = 1e-8 @pytest.mark.parametrize("ddp", [True, False]) def test_scc(self, preds, target, ddp): @@ -67,5 +67,5 @@ def test_scc_functional(self, preds, target): preds, target, metric_functional=spatial_correlation_coefficient, - reference_metric=_reference_scc, + reference_metric=_reference_scc ) From a9c25eea0bcdc556cf34beb6b660b4e18530c60a Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Mon, 4 Dec 2023 20:05:33 +0000 Subject: [PATCH 12/21] added new tests for scc functional interface' --- tests/unittests/image/test_scc.py | 32 +++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 6ff6fbb5657..6a988740b45 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -32,6 +32,10 @@ ) for channels in [1, 3] ] +_kernels = [ + torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), + torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]), +] def _reference_scc(preds, target): @@ -47,6 +51,22 @@ def _reference_scc(preds, target): ] return np.mean(scc) +def _wrapped_reference_scc(win, ws, reduction): + """Wrapper around reference implementation of scc from sewar.""" + def _wrapped(preds, target): + preds = torch.movedim(preds, 1, -1) + target = torch.movedim(target, 1, -1) + preds = preds.cpu().numpy() + target = target.cpu().numpy() + scc = [ + sewar_scc(GT=target[batch], P=preds[batch], win=win, ws=ws) for batch in range(preds.shape[0]) + ] + if reduction == 'mean': + return np.mean(scc) + if reduction == 'none': + return scc + return _wrapped + @pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) class TestSpatialCorrelationCoefficient(MetricTester): @@ -61,11 +81,19 @@ def test_scc(self, preds, target, ddp): ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc ) - def test_scc_functional(self, preds, target): + @pytest.mark.parametrize("hp_filter", _kernels) + @pytest.mark.parametrize("window_size", [8, 11]) + @pytest.mark.parametrize("reduction", ['mean', 'none']) + def test_scc_functional(self, preds, target, hp_filter, window_size, reduction): """Test SpatialCorrelationCoefficient functional usage.""" self.run_functional_metric_test( preds, target, metric_functional=spatial_correlation_coefficient, - reference_metric=_reference_scc + reference_metric=_wrapped_reference_scc(hp_filter, window_size, reduction), + metric_args={ + "hp_filter": hp_filter, + "window_size": window_size, + "reduction": reduction, + } ) From 2d9300b96001f3c50ce1dcb00ef4b7a8ecff6051 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Mon, 4 Dec 2023 20:22:07 +0000 Subject: [PATCH 13/21] resolving failed mypy checks. --- src/torchmetrics/functional/image/scc.py | 2 +- src/torchmetrics/image/scc.py | 4 ++-- tests/unittests/image/test_scc.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 59a28aa6105..478ce3a4b44 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -167,7 +167,7 @@ def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, w def spatial_correlation_coefficient( preds: Tensor, target: Tensor, - hp_filter: Tensor = None, + hp_filter: Optional[Tensor] = None, window_size: int = 8, reduction: Optional[Literal["mean", "none", None]] = "mean", ) -> Tensor: diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 185a19ebf96..4631b9434ed 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch from torch import Tensor, tensor @@ -57,7 +57,7 @@ class SpatialCorrelationCoefficient(Metric): scc_score: Tensor total: Tensor - def __init__(self, high_pass_filter: Tensor = None, window_size: int = 8, **kwargs: Any) -> None: + def __init__(self, high_pass_filter: Optional[Tensor] = None, window_size: int = 8, **kwargs: Any) -> None: super().__init__(**kwargs) if high_pass_filter is None: diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 6a988740b45..05714d6c4c5 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -33,8 +33,7 @@ for channels in [1, 3] ] _kernels = [ - torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]), - torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]), + torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) ] From 58b39352e06fdead99e1b387bd9289e0fb56da9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Dec 2023 20:24:03 +0000 Subject: [PATCH 14/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/scc.py | 5 +++-- tests/unittests/image/test_scc.py | 20 ++++++++++---------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 478ce3a4b44..876332b87cd 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple, Union, Optional -from typing_extensions import Literal +from typing import Optional, Tuple, Union import torch from torch import Tensor, tensor from torch.nn.functional import conv2d, pad +from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.distributed import reduce @@ -213,3 +213,4 @@ def spatial_correlation_coefficient( return torch.mean(torch.cat(per_channel, dim=1), dim=[1, 2, 3]) if reduction == "mean": return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean") + return None diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index 05714d6c4c5..6ef1443371c 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -32,9 +32,7 @@ ) for channels in [1, 3] ] -_kernels = [ - torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) -] +_kernels = [torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])] def _reference_scc(preds, target): @@ -50,20 +48,22 @@ def _reference_scc(preds, target): ] return np.mean(scc) + def _wrapped_reference_scc(win, ws, reduction): """Wrapper around reference implementation of scc from sewar.""" + def _wrapped(preds, target): preds = torch.movedim(preds, 1, -1) target = torch.movedim(target, 1, -1) preds = preds.cpu().numpy() target = target.cpu().numpy() - scc = [ - sewar_scc(GT=target[batch], P=preds[batch], win=win, ws=ws) for batch in range(preds.shape[0]) - ] - if reduction == 'mean': + scc = [sewar_scc(GT=target[batch], P=preds[batch], win=win, ws=ws) for batch in range(preds.shape[0])] + if reduction == "mean": return np.mean(scc) - if reduction == 'none': + if reduction == "none": return scc + return None + return _wrapped @@ -82,7 +82,7 @@ def test_scc(self, preds, target, ddp): @pytest.mark.parametrize("hp_filter", _kernels) @pytest.mark.parametrize("window_size", [8, 11]) - @pytest.mark.parametrize("reduction", ['mean', 'none']) + @pytest.mark.parametrize("reduction", ["mean", "none"]) def test_scc_functional(self, preds, target, hp_filter, window_size, reduction): """Test SpatialCorrelationCoefficient functional usage.""" self.run_functional_metric_test( @@ -94,5 +94,5 @@ def test_scc_functional(self, preds, target, hp_filter, window_size, reduction): "hp_filter": hp_filter, "window_size": window_size, "reduction": reduction, - } + }, ) From 5a501ed3d08b9c16e191cddbe515811ab15453f9 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Mon, 4 Dec 2023 20:32:05 +0000 Subject: [PATCH 15/21] resolved long docstring line --- src/torchmetrics/functional/image/scc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 478ce3a4b44..467f9208d96 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -178,7 +178,8 @@ def spatial_correlation_coefficient( target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``. hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]) window_size: Local window size integer. default: 8, - reduction: Reduction method for output tensor. If ``None`` or ``"none"``, returns a tensor with the per sample results. default: ``"mean"``. + reduction: Reduction method for output tensor. If ``None`` or ``"none"``, + returns a tensor with the per sample results. default: ``"mean"``. Return: Tensor with scc score From 38c9f0f7fce28815e7fe64bf8df0e970ab5f777f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Dec 2023 20:40:20 +0000 Subject: [PATCH 16/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/scc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index de7d409eb98..68cf346ffaa 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -178,7 +178,7 @@ def spatial_correlation_coefficient( target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``. hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]) window_size: Local window size integer. default: 8, - reduction: Reduction method for output tensor. If ``None`` or ``"none"``, + reduction: Reduction method for output tensor. If ``None`` or ``"none"``, returns a tensor with the per sample results. default: ``"mean"``. Return: From 650d227e52fe49f8dcdea1768bc2f48b7aaa01d0 Mon Sep 17 00:00:00 2001 From: HoseinAkbarzadeh Date: Tue, 5 Dec 2023 06:19:02 +0000 Subject: [PATCH 17/21] fixed example bug in docstring --- src/torchmetrics/functional/image/scc.py | 6 +++++- src/torchmetrics/image/scc.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index de7d409eb98..f1397f0ccd0 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -194,11 +194,15 @@ def spatial_correlation_coefficient( >>> x = torch.randn(5, 16, 16) >>> scc(x, x) tensor(1.) + >>> x = torch.randn(5, 3, 16, 16) + >>> y = torch.randn(5, 3, 16, 16) + >>> scc(x, y, reduction="none") + tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170]) """ if hp_filter is None: hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) - if reduce is None: + if reduction is None: reduction = "none" if reduction not in ("mean", "none"): raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}") diff --git a/src/torchmetrics/image/scc.py b/src/torchmetrics/image/scc.py index 4631b9434ed..15ea2b96ecf 100644 --- a/src/torchmetrics/image/scc.py +++ b/src/torchmetrics/image/scc.py @@ -46,7 +46,7 @@ class SpatialCorrelationCoefficient(Metric): >>> target = torch.randn([32, 3, 64, 64]) >>> scc = SCC() >>> scc(preds, target) - tensor(0.0022) + tensor(0.0023) """ From 45b6af461d87011a4ebe729bb6da6ff1e89614f3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 20 Dec 2023 15:37:57 +0100 Subject: [PATCH 18/21] Update src/torchmetrics/functional/image/scc.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/image/scc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/scc.py b/src/torchmetrics/functional/image/scc.py index 328500b1acf..167ddbd37b5 100644 --- a/src/torchmetrics/functional/image/scc.py +++ b/src/torchmetrics/functional/image/scc.py @@ -74,7 +74,7 @@ def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: i def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor: - """Applies symmetric padding to the 2D image tensor input using `reflect` mode (d c b a | a b c d | d c b a).""" + """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a).""" if isinstance(pad, int): pad = (pad, pad, pad, pad) if len(pad) != 4: From 6c9308f97845a987484b9e42710abed6260c5838 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 20 Dec 2023 15:38:33 +0100 Subject: [PATCH 19/21] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e51fcd6f4c1..2cacd957a94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `aggregate`` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220)) +- Added `Spatial Correlation Coefficient` to image subpackage ([#2248](https://github.com/Lightning-AI/torchmetrics/pull/2248)) + ### Changed - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145)) From 3426755ccd064f6dbd280530d4f6e1c36141436d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 21 Dec 2023 22:09:33 +0100 Subject: [PATCH 20/21] Apply suggestions from code review --- docs/source/links.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index 9ff27f4c95a..488453b9442 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -167,4 +167,4 @@ .. _FLORES-101: https://arxiv.org/abs/2106.03193 .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html -.. _SCC: https://www.tandfonline.com/doi/abs/10.1080/014311698215973 +.. _SCC: https://www.tandfonline.com/doi/10.1080/014311698215973 From b4892d567bd1a0772cd1721fea37a151a83af74f Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 21 Dec 2023 22:25:10 +0100 Subject: [PATCH 21/21] link --- docs/source/links.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index 488453b9442..6ffe815325e 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -167,4 +167,4 @@ .. _FLORES-101: https://arxiv.org/abs/2106.03193 .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html -.. _SCC: https://www.tandfonline.com/doi/10.1080/014311698215973 +.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013