Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new metric SCC for Images #800

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
223b5ed
Demo Pull Request
nishant42491 Jan 26, 2022
4267723
Merge branch 'PyTorchLightning:master' into image_metric_spatial_corr…
nishant42491 Feb 23, 2022
7d51822
Created function Scc metric
nishant42491 Mar 26, 2022
ba031b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2022
8927bec
Creating a functional interface for scc metric
nishant42491 Mar 28, 2022
566ee19
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 28, 2022
0dd5b9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
3195592
Refining the documentation for functional scc metric
nishant42491 Mar 28, 2022
89ec481
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
96909be
Merge branch 'master' into image_metric_spatial_correlation_coefficient
nishant42491 Mar 28, 2022
ec1585b
refining docstrings and code to pass tests
nishant42491 Mar 28, 2022
13a8822
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 28, 2022
9f52f43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
4541d53
Formatting Python Code
nishant42491 Mar 28, 2022
1adef23
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 28, 2022
7778463
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
08310d5
Created a class interface for scc metric
nishant42491 Mar 30, 2022
3dfdeac
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 30, 2022
9e6c6aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2022
513579e
formatting functional code
nishant42491 Mar 30, 2022
cb73488
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 30, 2022
34d3334
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2022
c7b5e17
formatting code
nishant42491 Mar 30, 2022
8f841d4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2022
4df86b3
added test_scc.py file
nishant42491 Mar 30, 2022
13875d9
Merge remote-tracking branch 'origin/image_metric_spatial_correlation…
nishant42491 Mar 30, 2022
7eeaa16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2022
f172f74
Merge branch 'master' into image_metric_spatial_correlation_coefficient
nishant42491 Mar 30, 2022
e868739
Merge branch 'master' into image_metric_spatial_correlation_coefficient
SkafteNicki Apr 26, 2022
6da16fd
init files
SkafteNicki Apr 26, 2022
de32c55
changelog
SkafteNicki Apr 26, 2022
ee3ff90
docs
SkafteNicki Apr 26, 2022
ae20a59
remove compute on step
SkafteNicki Apr 26, 2022
8ce899a
implement tests
SkafteNicki Apr 26, 2022
e58c6e8
Merge branch 'master' into image_metric_spatial_correlation_coefficient
nishant42491 Apr 26, 2022
1e68370
try fixing tests
SkafteNicki Apr 27, 2022
b4bb94b
Merge branch 'image_metric_spatial_correlation_coefficient' of https:…
SkafteNicki Apr 27, 2022
9d6449b
Merge branch 'master' into image_metric_spatial_correlation_coefficient
Borda May 5, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added


- Added `SpatialCorrelationCoefficient` to image domain ([#800](https://github.com/PyTorchLightning/metrics/pull/800))


- Added `RetrievalPrecisionRecallCurve` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


- Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))


Expand Down
22 changes: 22 additions & 0 deletions docs/source/image/spatial_correlation_coefficient.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Spatial Correlation Coefficient (SCC)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#####################################
Spatial Correlation Coefficient (SCC)
#####################################

Module Interface
________________

.. autoclass:: torchmetrics.SpatialCorrelationCoefficient
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.spatial_correlation_coefficient
:noindex:
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
scikit-image>0.17.1
pytorch_msssim
sewar
105 changes: 105 additions & 0 deletions tests/image/test_scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from sewar.full_ref import scc

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.image.scc import spatial_correlation_coefficient
from torchmetrics.image.scc import SpatialCorrelationCoefficient

seed_all(42)


def _reference_scc(preds, target, reduction):
val = 0.0
for p, t in zip(preds, target):
val += scc(t.permute(1, 2, 0).numpy(), p.permute(1, 2, 0).numpy(), ws=9)
val = val if reduction == "sum" else val / preds.shape[0]
return val


Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, dtype in [
(12, 3, torch.uint8),
(13, 3, torch.float32),
(14, 3, torch.double),
(15, 3, torch.float64),
]:
preds = torch.randint(0, 255, (NUM_BATCHES, BATCH_SIZE, channel, size, size), dtype=dtype)
target = torch.randint(0, 255, (NUM_BATCHES, BATCH_SIZE, channel, size, size), dtype=dtype)
_inputs.append(Input(preds=preds, target=target))


@pytest.mark.parametrize("reduction", ["sum", "elementwise_mean"])
@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
class TestSpatialCorrelationCoefficient(MetricTester):
atol = 1e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_scc(self, reduction, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SpatialCorrelationCoefficient,
partial(_reference_scc, reduction=reduction),
dist_sync_on_step,
metric_args=dict(reduction=reduction),
)

def test_scc_functional(self, reduction, preds, target):
self.run_functional_metric_test(
preds,
target,
spatial_correlation_coefficient,
partial(_reference_scc, reduction=reduction),
metric_args=dict(reduction=reduction),
)

# SAM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SCC metric does not support cpu + half precision")
def test_scc_half_cpu(self, reduction, preds, target):
self.run_precision_test_cpu(
preds,
target,
SpatialCorrelationCoefficient,
spatial_correlation_coefficient,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_scc_half_gpu(self, reduction, preds, target):
self.run_precision_test_gpu(preds, target, SpatialCorrelationCoefficient, spatial_correlation_coefficient)


def test_error_on_different_shape(metric_class=SpatialCorrelationCoefficient):
metric = metric_class()
with pytest.raises(RuntimeError):
metric(torch.randn([1, 3, 16, 16]), torch.randn([1, 1, 16, 16]))


def test_error_on_invalid_shape(metric_class=SpatialCorrelationCoefficient):
metric = metric_class()
with pytest.raises(ValueError):
metric(torch.randn([3, 16, 16]), torch.randn([3, 16, 16]))
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
SpatialCorrelationCoefficient,
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
Expand Down Expand Up @@ -176,6 +177,7 @@
"ScaleInvariantSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"SpatialCorrelationCoefficient",
"SpearmanCorrCoef",
"Specificity",
"SpectralAngleMapper",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.scc import spatial_correlation_coefficient
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down Expand Up @@ -154,6 +155,7 @@
"scale_invariant_signal_noise_ratio",
"signal_noise_ratio",
"spearman_corrcoef",
"spatial_correlation_coefficient",
"specificity",
"spectral_distortion_index",
"squad",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.scc import spatial_correlation_coefficient # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
Expand Down
200 changes: 200 additions & 0 deletions torchmetrics/functional/image/scc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch.nn import functional as F
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _scc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Spatial Correlation Coefficient.

Checks for same shape and
type of the input tensors.
Args:
preds: Predicted tensor
target: Ground truth tensor
"""

if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)
return preds, target


def _scc_compute(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (9, 9),
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Union[Tensor, Tuple[Tensor, Tensor]]:

"""Args:
preds: estimated image
target: ground truth image
kernel_size: size of the Uniform kernel (default: (9, 9))

reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied


Return:
Tensor with Spatial Correlation Coefficient score

Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
ValueError:
If the length of ``kernel_size`` is not ``2``.
ValueError:
If one of the elements of ``kernel_size`` is not an ``odd positive number``."""

if len(kernel_size) != 2:
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two." f" Got kernel_size: {len(kernel_size)}."
)

if any(x % 2 == 0 or x <= 0 for x in kernel_size):
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

batch_size = preds.shape[0]

classes = preds.shape[1]

coefs = torch.zeros((batch_size, classes, preds.shape[2], preds.shape[3]))

kernel = torch.div(
torch.ones(1, 1, kernel_size[0], kernel_size[1], dtype=preds.dtype, device=preds.device),
kernel_size[0] * kernel_size[1],
)

for i in range(classes):

mu1, mu2 = F.conv2d(preds[:, i, :, :].unsqueeze(1), kernel, padding="same"), F.conv2d(
target[:, i, :, :].unsqueeze(1), kernel, padding="same"
)

preds_sum_sq, target_sum_sq, preds_target_sum_mul = mu1 * mu1, mu2 * mu2, mu1 * mu2

outputs_1 = F.conv2d(
preds[:, i, :, :].unsqueeze(1) * preds[:, i, :, :].unsqueeze(1),
kernel,
padding="same",
)
outputs_1 -= preds_sum_sq

outputs_2 = F.conv2d(
target[:, i, :, :].unsqueeze(1) * target[:, i, :, :].unsqueeze(1),
kernel,
padding="same",
)
outputs_2 -= target_sum_sq

outputs_3 = F.conv2d(
preds[:, i, :, :].unsqueeze(1) * target[:, i, :, :].unsqueeze(1),
kernel,
padding="same",
)
outputs_3 -= preds_target_sum_mul

sigma_preds_sq, sigma_target_sq, sigma_preds_target = outputs_1, outputs_2, outputs_3

sigma_preds_sq[sigma_preds_sq < 0] = 0
sigma_target_sq[sigma_target_sq < 0] = 0

den = torch.sqrt(sigma_preds_sq) * torch.sqrt(sigma_target_sq)

idx = den == 0

den[den == 0] = 1

scc = sigma_preds_target / den

scc[idx] = 0

coefs[:, i, :, :] = scc.squeeze(1)

batch_score = []
for i in range(scc.shape[0]):
batch_score.append(torch.mean(scc[i, :, :, :]))

final_batch_score = torch.as_tensor(batch_score)

final_batch_score = reduce(final_batch_score, reduction=reduction)

return final_batch_score


def spatial_correlation_coefficient(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (9, 9),
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Spatial Correlation Coefficient.

Args:
preds: estimated image
target: ground truth image
kernel_size: size of the Uniform kernel (default: (9, 9))

reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

Return:
Tensor with Spatial Correlation Coefficient score

Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
ValueError:
If the length of ``kernel_size`` is not ``2``.
ValueError:
If one of the elements of ``kernel_size`` is not an ``odd positive number``.

Example:
>>> from torchmetrics.functional.image.scc import spatial_correlation_coefficient
>>> preds = torch.ones([16, 3, 16, 16])
>>> target = torch.ones([16, 3, 16, 16])
>>> spatial_correlation_coefficient(preds, target)
tensor(1.)
"""

preds, target = _scc_update(preds, target)

return _scc_compute(preds, target, kernel_size, reduction)
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis # noqa: F401
from torchmetrics.image.psnr import PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.sam import SpectralAngleMapper # noqa: F401
from torchmetrics.image.scc import SpatialCorrelationCoefficient # noqa: F401
from torchmetrics.image.ssim import ( # noqa: F401
MultiScaleStructuralSimilarityIndexMeasure,
StructuralSimilarityIndexMeasure,
Expand Down
Loading