Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spatial Correlation Coefficient (SCC) metric #2248

Merged
merged 31 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c3649ca
SpatialCorrelationCoefficient functionality and module added.
HoseinAkbarzadeh Nov 29, 2023
be3da87
tests for spatial correlation coefficient (scc) added.
HoseinAkbarzadeh Nov 29, 2023
d678f2a
documentation for spatial correlation coefficient (scc) updated.
HoseinAkbarzadeh Nov 29, 2023
952233b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2023
48164a4
Apply suggestions from code review
Borda Nov 30, 2023
2637b23
scc functional docstrings added. _hp_2d_laplacian function updated
HoseinAkbarzadeh Nov 30, 2023
4a0058a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
0668236
required failed checks resolved.
HoseinAkbarzadeh Nov 30, 2023
f231192
fixing merge conflict
HoseinAkbarzadeh Nov 30, 2023
24bfc77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
33b6d90
fixing the variable name mistake.
HoseinAkbarzadeh Nov 30, 2023
f78db34
merge conflict resovled.
HoseinAkbarzadeh Nov 30, 2023
52bdb76
Merge branch 'master' into master
SkafteNicki Nov 30, 2023
f5a874f
fixed even window size bug. changed atol to 1e-8. added None reductio…
HoseinAkbarzadeh Dec 4, 2023
a9c25ee
added new tests for scc functional interface'
HoseinAkbarzadeh Dec 4, 2023
2d9300b
resolving failed mypy checks.
HoseinAkbarzadeh Dec 4, 2023
58b3935
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
5a501ed
resolved long docstring line
HoseinAkbarzadeh Dec 4, 2023
7df880f
merge fix
HoseinAkbarzadeh Dec 4, 2023
38c9f0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
650d227
fixed example bug in docstring
HoseinAkbarzadeh Dec 5, 2023
7605971
fixing merge conflict
HoseinAkbarzadeh Dec 5, 2023
3e577ee
Merge branch 'master' into master
HoseinAkbarzadeh Dec 5, 2023
f423c0c
Merge branch 'master' into master
SkafteNicki Dec 20, 2023
45b6af4
Update src/torchmetrics/functional/image/scc.py
SkafteNicki Dec 20, 2023
6c9308f
changelog
SkafteNicki Dec 20, 2023
9a6f888
Merge branch 'master' into master
mergify[bot] Dec 21, 2023
b5b9702
Merge branch 'master' into HoseinAkbarzadeh/master
Borda Dec 21, 2023
3426755
Apply suggestions from code review
Borda Dec 21, 2023
b4892d5
link
Borda Dec 21, 2023
5e18e82
Merge branch 'master' into master
SkafteNicki Dec 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions docs/source/image/spatial_correlation_coefficient.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Spatial Correlation Coefficient (SCC)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#################################
Spatial Correlation Coefficient (SCC)
#################################
Borda marked this conversation as resolved.
Show resolved Hide resolved

Module Interface
________________

.. autoclass:: torchmetrics.image.SpatialCorrelationCoefficient
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.spatial_correlation_coefficient
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@
.. _FLORES-101: https://arxiv.org/abs/2106.03193
.. _FLORES-200: https://arxiv.org/abs/2207.04672
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.tandfonline.com/doi/abs/10.1080/014311698215973
Borda marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.scc import spatial_correlation_coefficient
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand All @@ -45,4 +46,5 @@
"visual_information_fidelity",
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
"spatial_correlation_coefficient",
]
177 changes: 177 additions & 0 deletions src/torchmetrics/functional/image/scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from typing import Tuple, Union
Borda marked this conversation as resolved.
Show resolved Hide resolved

import torch
from torch import Tensor, tensor
from torch.nn.functional import conv2d

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Update and returns variables required to compute Spatial Correlation Coefficient.

Args:
preds: Predicted tensor
target: Ground truth tensor
hp_filter: High-pass filter tensor
window_size: Local window size integer

Return:
Tuple of (preds, target, hp_filter) tensors

Raises:
ValueError:
If ``preds`` and ``target`` have different number of channels
If ``preds`` and ``target`` have different shapes
If ``preds`` and ``target`` have invalid shapes
If ``window_size`` is not a positive integer
If ``window_size`` is greater than the size of the image

"""
if preds.dtype != target.dtype:
target = target.to(preds.dtype)
_check_same_shape(preds, target)
if len(preds.shape) not in (3, 4):
Borda marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Expected `preds` and `target` to have batch of colored images with BxCxHxW shape"
" or batch of grayscale images of BxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)

if len(preds.shape) == 3:
preds = preds.unsqueeze(1)
target = target.unsqueeze(1)

if not window_size > 0:
raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.")

if window_size > preds.size(2) or window_size > preds.size(3):
raise ValueError(
f"Expected `window_size` to be less than or equal to the size of the image."
f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}."
)

preds = preds.to(torch.float32)
target = target.to(torch.float32)
hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device)
return preds, target, hp_filter


def _symmetric_reflect_pad_2d(input: Tensor, pad: Union[int, Tuple[int, ...]]) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(pad, int):
pad = (pad, pad, pad, pad)
assert len(pad) == 4
Borda marked this conversation as resolved.
Show resolved Hide resolved

left_pad = input[:, :, :, 0 : pad[0]].flip(dims=[3])
right_pad = input[:, :, :, -pad[1] :].flip(dims=[3])
padded = torch.cat([left_pad, input, right_pad], dim=3)

top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2])
bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2])
return torch.cat([top_pad, padded, bottom_pad], dim=2)


def _signal_convolve_2d(input: Tensor, kernel: Tensor) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
left_padding = int(torch.floor(tensor((kernel.size(3) - 1) / 2)).item())
right_padding = int(torch.ceil(tensor((kernel.size(3) - 1) / 2)).item())
top_padding = int(torch.floor(tensor((kernel.size(2) - 1) / 2)).item())
bottom_padding = int(torch.ceil(tensor((kernel.size(2) - 1) / 2)).item())
Borda marked this conversation as resolved.
Show resolved Hide resolved

padded = _symmetric_reflect_pad_2d(input, pad=(left_padding, right_padding, top_padding, bottom_padding))
kernel = kernel.flip([2, 3])
return conv2d(padded, kernel, stride=1, padding=0)


def _hp_2d_laplacian(input: Tensor, kernel: Tensor) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
output = _signal_convolve_2d(input, kernel)
output += _signal_convolve_2d(input, kernel)
return output


def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved
preds_mean = conv2d(preds, window, stride=1, padding="same")
target_mean = conv2d(target, window, stride=1, padding="same")

preds_var = conv2d(preds**2, window, stride=1, padding="same") - preds_mean**2
target_var = conv2d(target**2, window, stride=1, padding="same") - target_mean**2
target_preds_cov = conv2d(target * preds, window, stride=1, padding="same") - target_mean * preds_mean

return preds_var, target_var, target_preds_cov


def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int):
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Computes per channel Spatial Correlation Coefficient.

Args:
preds: estimated image of Bx1xHxW shape.
target: ground truth image of Bx1xHxW shape.
hp_filter: 2D high-pass filter.
window_size: size of window for local mean calculation.

Return:
Tensor with Spatial Correlation Coefficient score

"""
dtype = preds.dtype
device = preds.device

# This code is inspired by
# https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.

window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2)

preds_hp = _hp_2d_laplacian(preds, hp_filter)
target_hp = _hp_2d_laplacian(target, hp_filter)

preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window)

preds_var[preds_var < 0] = 0
target_var[target_var < 0] = 0

den = torch.sqrt(target_var) * torch.sqrt(preds_var)
idx = den == 0
den[den == 0] = 1
scc = target_preds_cov / den
scc[idx] = 0
return scc


def spatial_correlation_coefficient(
preds: Tensor,
target: Tensor,
hp_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]),
window_size: int = 8,
):
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Compute Spatial Correlation Coefficient (SCC_).

Args:
preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]])
window_size: Local window size integer. default: 8

Return:
Tensor with scc score

Example:
>>> import torch
>>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc
>>> _ = torch.manual_seed(42)
>>> x = torch.randn(5, 3, 16, 16)
>>> scc(x, x)
tensor(1.)
>>> x = torch.randn(5, 16, 16)
>>> scc(x, x)
tensor(1.)

"""
preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size)

per_channel = [
_scc_per_channel_compute(
preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size
)
for i in range(preds.size(1))
]
return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean")
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torchmetrics.image.rase import RelativeAverageSpectralError
from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow
from torchmetrics.image.sam import SpectralAngleMapper
from torchmetrics.image.scc import SpatialCorrelationCoefficient
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure
from torchmetrics.image.tv import TotalVariation
from torchmetrics.image.uqi import UniversalImageQualityIndex
Expand All @@ -42,6 +43,7 @@
"UniversalImageQualityIndex",
"VisualInformationFidelity",
"TotalVariation",
"SpatialCorrelationCoefficient",
]

if _TORCH_FIDELITY_AVAILABLE:
Expand Down
73 changes: 73 additions & 0 deletions src/torchmetrics/image/scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any
Borda marked this conversation as resolved.
Show resolved Hide resolved

import torch
from torch import Tensor, tensor

from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute
from torchmetrics.functional.image.scc import _scc_update
from torchmetrics.metric import Metric


class SpatialCorrelationCoefficient(Metric):
"""Compute Spatial Correlation Coefficient (SCC_).

As input to ``forward`` and ``update`` the metric accepts the following input

- ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``.
- ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``.

As output of `forward` and `compute` the metric returns the following output

- ``scc`` (:class:`~torch.Tensor`): Tensor with scc score

Args:
hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]).
window_size: Local window size integer. default: 8.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC
>>> preds = torch.randn([32, 3, 64, 64])
>>> target = torch.randn([32, 3, 64, 64])
>>> scc = SCC()
>>> scc(preds, target)
tensor(0.0022)

"""

is_differentiable = True
higher_is_better = True
full_state_update = False

scc_score: Tensor
total: Tensor

def __init__(
self,
high_pass_filter: Tensor = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]),
window_size: int = 11,
**kwargs: Any
) -> None:
super().__init__(**kwargs)

self.hp_filter = high_pass_filter
self.ws = window_size

self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws)
scc_per_channel = [
_scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws)
for i in range(preds.size(1))
]
self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3]))
self.total += preds.size(0)

def compute(self) -> Tensor:
"""Compute the VIF score based on inputs passed in to ``update`` previously."""
return self.scc_score / self.total
56 changes: 56 additions & 0 deletions tests/unittests/image/test_scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections import namedtuple
Borda marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np
import pytest
import torch
from sewar.full_ref import scc as sewar_scc
from torchmetrics.functional.image import spatial_correlation_coefficient
from torchmetrics.image import SpatialCorrelationCoefficient

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])
_inputs = [
Input(
preds=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128),
target=torch.randn(NUM_BATCHES, BATCH_SIZE, channels, 128, 128),
)
for channels in [1, 3]
]


def _reference_scc(preds, target):
"""Reference implementation of scc from sewar."""
preds = torch.movedim(preds, 1, -1)
target = torch.movedim(target, 1, -1)
preds = preds.cpu().numpy()
target = target.cpu().numpy()
hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
window_size = 8
scc = [
sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0])
]
return np.mean(scc)


@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs])
class TestSpatialCorrelationCoefficient(MetricTester):
atol = 1e-3

@pytest.mark.parametrize("ddp", [True, False])
def test_scc(self, preds, target, ddp):
self.run_class_metric_test(
ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc
)

def test_scc_functional(self, preds, target):
self.run_functional_metric_test(
preds,
target,
metric_functional=spatial_correlation_coefficient,
reference_metric=_reference_scc,
)
Loading