Skip to content

Commit

Permalink
Added new image metric - UQI (#824)
Browse files Browse the repository at this point in the history
* Added new metric - UQI
* Registered UQI to functional init; tested locally
* Testcases added for UQI
* Apply suggestions from code review
* Update requirements.txt

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
6 people authored Feb 7, 2022
1 parent 7330e2e commit c0e4250
Show file tree
Hide file tree
Showing 13 changed files with 526 additions and 46 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))


- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))


### Changed


Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
.. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
.. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/document/995823
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ peak_signal_noise_ratio [func]
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
:noindex:

universal_image_quality_index [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.universal_image_quality_index
:noindex:


**********
Regression
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ StructuralSimilarityIndexMeasure
.. autoclass:: torchmetrics.StructuralSimilarityIndexMeasure
:noindex:

UniversalImageQualityIndex
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.UniversalImageQualityIndex
:noindex:

*********
Detection
*********
Expand Down
174 changes: 174 additions & 0 deletions tests/image/test_uqi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from skimage.metrics import structural_similarity

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.image.uqi import UniversalImageQualityIndex

seed_all(42)

# UQI is SSIM with both constants k1 and k2 as 0
skimage_uqi = partial(structural_similarity, k1=0, k2=0)

Input = namedtuple("Input", ["preds", "target", "multichannel"])

_inputs = []
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
(13, 1, 0.8, False, torch.float32),
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
multichannel=multichannel,
)
)


def _sk_uqi(preds, target, data_range, multichannel, kernel_size):
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
if not multichannel:
sk_preds = sk_preds[:, :, :, 0]
sk_target = sk_target[:, :, :, 0]

return skimage_uqi(
sk_target,
sk_preds,
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)


@pytest.mark.parametrize(
"preds, target, multichannel",
[(i.preds, i.target, i.multichannel) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 11])
class TestUQI(MetricTester):
atol = 6e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_uqi(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
UniversalImageQualityIndex,
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
dist_sync_on_step=dist_sync_on_step,
)

def test_uqi_functional(self, preds, target, multichannel, kernel_size):
self.run_functional_metric_test(
preds,
target,
universal_image_quality_index,
partial(_sk_uqi, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
)

# UQI half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="UQI metric does not support cpu + half precision")
def test_uqi_half_cpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_cpu(
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_uqi_half_gpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_gpu(
preds, target, UniversalImageQualityIndex, universal_image_quality_index, {"data_range": 1.0}
)


@pytest.mark.parametrize(
["pred", "target", "kernel", "sigma"],
[
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
],
)
def test_uqi_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
universal_image_quality_index(pred_t, target_t)

pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
universal_image_quality_index(pred, target, kernel, sigma)


def test_uqi_unequal_kernel_size():
"""Test the case where kernel_size[0] != kernel_size[1]"""
preds = torch.tensor(
[
[
[
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
[1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
]
]
]
)
target = torch.tensor(
[
[
[
[1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
]
]
]
)
# kernel order matters
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(3, 5)), torch.tensor(0.10662283))
torch.allclose(universal_image_quality_index(preds, target, kernel_size=(5, 3)), torch.tensor(0.10662283))
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
UniversalImageQualityIndex,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.regression import ( # noqa: E402
Expand Down Expand Up @@ -159,6 +160,7 @@
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TranslationEditRate",
"UniversalImageQualityIndex",
"WordErrorRate",
"CharErrorRate",
"MatchErrorRate",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
from torchmetrics.functional.pairwise.linear import pairwise_linear_similarity
Expand Down Expand Up @@ -142,6 +143,7 @@
"stat_scores",
"symmetric_mean_absolute_percentage_error",
"translation_edit_rate",
"universal_image_quality_index",
"word_error_rate",
"char_error_rate",
"match_error_rate",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
50 changes: 50 additions & 0 deletions torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Sequence

import torch
from torch import Tensor


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
"""Computes 1D gaussian kernel.
Args:
kernel_size: size of the gaussian kernel
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
Example:
>>> _gaussian(3, 1, torch.float, 'cpu')
tensor([[0.2741, 0.4519, 0.2741]])
"""
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 2D gaussian kernel.
Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
47 changes: 1 addition & 46 deletions torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,11 @@
from torch.nn import functional as F
from typing_extensions import Literal

from torchmetrics.functional.image.helper import _gaussian_kernel
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
"""Computes 1D gaussian kernel.
Args:
kernel_size: size of the gaussian kernel
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
Example:
>>> _gaussian(3, 1, torch.float, 'cpu')
tensor([[0.2741, 0.4519, 0.2741]])
"""
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 2D gaussian kernel.
Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Updates and returns variables required to compute Structural Similarity Index Measure. Checks for same shape
and type of the input tensors.
Expand Down
Loading

0 comments on commit c0e4250

Please sign in to comment.