Skip to content

Commit

Permalink
Add MultiScaleStructuralSimilarityIndexMeasure (#679)
Browse files Browse the repository at this point in the history
* [WIP] Add MS-SSIM

* Fix docs and a mypy issue

* fix badge

* Apply some SkafteNicki's suggestions from code review

* Fix docstring: MS_SSIM -> MultiScaleSSIM

* Fix missed MS_SSIM in a docstring

* flake8

* Add tests

* Update doc example

* Fix seed for doc + update some tests

* Clean code

* Update doc example + make diff test tractable

* torch.manual seed passed directly to torch.rand

* Apply suggestions from code review

* Drop some tests ; fix docs issues + add missing docs

* Update paper.md (#690)

* ci: rename oldest

* CI: set HF caching (#691)

* Apply suggestions: ms_ssim -> multiscale_ssim + typing

* Update doc reference

* Apply SkafteNicki's suggestions

* Update test reference package

* Clean test

* Change naming multiscale_ssim -> multiscale_structual_similarity_index_measure

* Update docs references

* Fix a typo: structual -> structural

* fix one last typo

* Fix ~ len in a doc

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Maxim Grechkin <maximsch2@gmail.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
6 people authored Jan 10, 2022
1 parent c20860a commit 9288222
Show file tree
Hide file tree
Showing 14 changed files with 475 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))
- `TranslationEditRate` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))

- Added `MultiScaleSSIM` into image metrics ([#679](https://github.com/PyTorchLightning/metrics/pull/679))


- Added a default VSCode devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621))

Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
# how much time to give 'run always even if cancelled tasks' before stopping them
cancelTimeoutInMinutes: "2"

pool: azure-gpus-spot
pool: gridai-spot-pool

container:
image: "pytorch/pytorch:1.8.1-cuda10.2-cudnn7-runtime"
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@
.. _chrF score: https://aclanthology.org/W15-3049.pdf
.. _chrF++ score: https://aclanthology.org/W17-4770.pdf
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,13 @@ image_gradients [func]
:noindex:


multiscale_structural_similarity_index_measure [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.multiscale_structural_similarity_index_measure
:noindex:


psnr [func]
~~~~~~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ LPIPS
.. autoclass:: torchmetrics.image.lpip_similarity.LPIPS
:noindex:

MultiScaleStructuralSimilarityIndexMeasure
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MultiScaleStructuralSimilarityIndexMeasure
:noindex:

PSNR
~~~~

Expand Down
3 changes: 3 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ rouge-score>=0.0.4
bert_score==0.3.10
transformers>=4.0
sacrebleu>=2.0.0

# image
pytorch_msssim
92 changes: 92 additions & 0 deletions tests/image/test_ms_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from pytorch_msssim import ms_ssim

from tests.helpers import seed_all
from tests.helpers.testers import NUM_BATCHES, MetricTester
from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure
from torchmetrics.image.ms_ssim import MultiScaleStructuralSimilarityIndexMeasure

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])
BATCH_SIZE = 1

_inputs = []
for size, coef in [(128, 0.9), (128, 0.7)]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 1, size, size)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
)
)


def pytorch_ms_ssim(preds, target, data_range, kernel_size):
return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size)


@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 7])
class TestMultiScaleStructuralSimilarityIndexMeasure(MetricTester):
atol = 6e-3

@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_ms_ssim(self, preds, target, kernel_size, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
MultiScaleStructuralSimilarityIndexMeasure,
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
dist_sync_on_step=dist_sync_on_step,
)

def test_ms_ssim_functional(self, preds, target, kernel_size):
self.run_functional_metric_test(
preds,
target,
multiscale_structural_similarity_index_measure,
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
)

def test_ms_ssim_differentiability(self, preds, target, kernel_size):
# We need to minimize this example to make the test tractable
single_beta = (1.0,)
_preds = preds[:, :, :, :16, :16]
_target = target[:, :, :, :16, :16]

self.run_differentiability_test(
_preds.type(torch.float64),
_target.type(torch.float64),
metric_functional=multiscale_structural_similarity_index_measure,
metric_module=MultiScaleStructuralSimilarityIndexMeasure,
metric_args={
"data_range": 1.0,
"kernel_size": (kernel_size, kernel_size),
"betas": single_beta,
},
)
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
Specificity,
StatScores,
)
from torchmetrics.image import PSNR, SSIM # noqa: E402
from torchmetrics.image import PSNR, SSIM, MultiScaleStructuralSimilarityIndexMeasure # noqa: E402
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.metric_collections import MetricCollection # noqa: E402
from torchmetrics.regression import ( # noqa: E402
Expand Down Expand Up @@ -135,6 +135,7 @@
"MinMaxMetric",
"MinMetric",
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PearsonCorrcoef",
"PearsonCorrCoef",
"PIT",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from torchmetrics.functional.classification.specificity import specificity
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure
from torchmetrics.functional.image.psnr import psnr
from torchmetrics.functional.image.ssim import ssim
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
Expand Down Expand Up @@ -105,6 +106,7 @@
"mean_absolute_percentage_error",
"mean_squared_error",
"mean_squared_log_error",
"multiscale_structural_similarity_index_measure",
"pairwise_cosine_similarity",
"pairwise_euclidean_distance",
"pairwise_linear_similarity",
Expand Down
204 changes: 204 additions & 0 deletions torchmetrics/functional/image/ms_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.image.ssim import _ssim_compute, _ssim_update


def _get_normalized_sim_and_cs(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
normalize: Optional[Literal["relu", "simple"]] = None,
) -> Tuple[Tensor, Tensor]:
sim, contrast_sensitivity = _ssim_compute(
preds, target, kernel_size, sigma, reduction, data_range, k1, k2, return_contrast_sensitivity=True
)
if normalize == "relu":
sim = torch.relu(sim)
contrast_sensitivity = torch.relu(contrast_sensitivity)
return sim, contrast_sensitivity


def _multiscale_ssim_compute(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = (
0.0448,
0.2856,
0.3001,
0.2363,
0.1333,
),
normalize: Optional[Literal["relu", "simple"]] = None,
) -> Tensor:
"""Computes Multi-Scale Structual Similarity Index Measure. Adapted from: https://github.com/jorge-
pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py.
Args:
preds: estimated image
target: ground truth image
kernel_size: size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of structural similarity index measure.
k2: Parameter of structural similarity index measure.
betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image
resolutions.
normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
Raises:
ValueError:
If the image height or width is smaller then ``2 ** len(betas)``.
ValueError:
If the image height is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
ValueError:
If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
"""
sim_list: List[Tensor] = []
cs_list: List[Tensor] = []

if preds.size()[-1] < 2 ** len(betas) or preds.size()[-2] < 2 ** len(betas):
raise ValueError(
f"For a given number of `betas` parameters {len(betas)}, the image height and width dimensions must be"
f" larger than or equal to {2 ** len(betas)}."
)

_betas_div = max(1, (len(betas) - 1)) ** 2
if preds.size()[-2] // _betas_div <= kernel_size[0] - 1:
raise ValueError(
f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[0]},"
f" the image height must be larger than {(kernel_size[0] - 1) * _betas_div}."
)
if preds.size()[-1] // _betas_div <= kernel_size[1] - 1:
raise ValueError(
f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[1]},"
f" the image width must be larger than {(kernel_size[1] - 1) * _betas_div}."
)

for _ in range(len(betas)):
sim, contrast_sensitivity = _get_normalized_sim_and_cs(
preds, target, kernel_size, sigma, reduction, data_range, k1, k2, normalize
)
sim_list.append(sim)
cs_list.append(contrast_sensitivity)
preds = F.avg_pool2d(preds, (2, 2))
target = F.avg_pool2d(target, (2, 2))

sim_stack = torch.stack(sim_list)
cs_stack = torch.stack(cs_list)

if normalize == "simple":
sim_stack = (sim_stack + 1) / 2
cs_stack = (cs_stack + 1) / 2

sim_stack = sim_stack ** torch.tensor(betas, device=sim_stack.device)
cs_stack = cs_stack ** torch.tensor(betas, device=cs_stack.device)
return torch.prod(cs_stack[:-1]) * sim_stack[-1]


def multiscale_structural_similarity_index_measure(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Optional[Literal["relu", "simple"]] = None,
) -> Tensor:
"""Computes `MultiScaleSSIM`_, Multi-scale Structual Similarity Index Measure, which is a generalization of
Structual Similarity Index Measure by incorporating image details at different resolution scores.
Args:
preds: Predictions from model of shape `[N, C, H, W]`
target: Ground truth values of shape `[N, C, H, W]`
kernel_size: size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of structural similarity index measure.
k2: Parameter of structural similarity index measure.
betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image
resolutions.
normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
Return:
Tensor with Multi-Scale SSIM score
Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
ValueError:
If the length of ``kernel_size`` or ``sigma`` is not ``2``.
ValueError:
If one of the elements of ``kernel_size`` is not an ``odd positive number``.
ValueError:
If one of the elements of ``sigma`` is not a ``positive number``.
Example:
>>> from torchmetrics.functional import multiscale_structural_similarity_index_measure
>>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target)
tensor(0.9558)
References:
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
Bovik `MultiScaleSSIM`_
"""
if not isinstance(betas, tuple):
raise ValueError("Argument `betas` is expected to be of a type tuple.")
if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas):
raise ValueError("Argument `betas` is expected to be a tuple of floats.")
if normalize and normalize not in ("relu", "simple"):
raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'")

preds, target = _ssim_update(preds, target)
return _multiscale_ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2, betas, normalize)
Loading

0 comments on commit 9288222

Please sign in to comment.