From 9288222e9ab5124acfd74c1bd31d702b61b6679a Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 10 Jan 2022 10:14:18 +0100 Subject: [PATCH] Add `MultiScaleStructuralSimilarityIndexMeasure` (#679) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [WIP] Add MS-SSIM * Fix docs and a mypy issue * fix badge * Apply some SkafteNicki's suggestions from code review * Fix docstring: MS_SSIM -> MultiScaleSSIM * Fix missed MS_SSIM in a docstring * flake8 * Add tests * Update doc example * Fix seed for doc + update some tests * Clean code * Update doc example + make diff test tractable * torch.manual seed passed directly to torch.rand * Apply suggestions from code review * Drop some tests ; fix docs issues + add missing docs * Update paper.md (#690) * ci: rename oldest * CI: set HF caching (#691) * Apply suggestions: ms_ssim -> multiscale_ssim + typing * Update doc reference * Apply SkafteNicki's suggestions * Update test reference package * Clean test * Change naming multiscale_ssim -> multiscale_structual_similarity_index_measure * Update docs references * Fix a typo: structual -> structural * fix one last typo * Fix ~ len in a doc Co-authored-by: Jirka Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli Co-authored-by: Maxim Grechkin Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 2 + azure-pipelines.yml | 2 +- docs/source/links.rst | 1 + docs/source/references/functional.rst | 7 + docs/source/references/modules.rst | 6 + requirements/test.txt | 3 + tests/image/test_ms_ssim.py | 92 ++++++++ torchmetrics/__init__.py | 3 +- torchmetrics/functional/__init__.py | 2 + torchmetrics/functional/image/ms_ssim.py | 204 ++++++++++++++++++ torchmetrics/functional/image/ssim.py | 10 +- .../functional/regression/__init__.py | 1 + torchmetrics/image/__init__.py | 1 + torchmetrics/image/ms_ssim.py | 145 +++++++++++++ 14 files changed, 475 insertions(+), 4 deletions(-) create mode 
100644 tests/image/test_ms_ssim.py create mode 100644 torchmetrics/functional/image/ms_ssim.py create mode 100644 torchmetrics/image/ms_ssim.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e4ffe33608a..efe200df024 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641)) - `TranslationEditRate` ([#646](https://github.com/PyTorchLightning/metrics/pull/646)) +- Added `MultiScaleSSIM` into image metrics ([#679](https://github.com/PyTorchLightning/metrics/pull/679)) + - Added a default VSCode devcontainer configuration ([#621](https://github.com/PyTorchLightning/metrics/pull/621)) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 01a15710b10..85abb916760 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -23,7 +23,7 @@ jobs: # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" - pool: azure-gpus-spot + pool: gridai-spot-pool container: image: "pytorch/pytorch:1.8.1-cuda10.2-cudnn7-runtime" diff --git a/docs/source/links.rst b/docs/source/links.rst index a24f4288251..e2839e34fa2 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -76,3 +76,4 @@ .. _chrF score: https://aclanthology.org/W15-3049.pdf .. _chrF++ score: https://aclanthology.org/W17-4770.pdf .. _TER: https://aclanthology.org/2006.amta-papers.25.pdf +.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index cdc63dc43d7..fd07d9c8d1d 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -241,6 +241,13 @@ image_gradients [func] :noindex: +multiscale_structural_similarity_index_measure [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: torchmetrics.functional.multiscale_structural_similarity_index_measure + :noindex: + + psnr [func] ~~~~~~~~~~~ diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 4c83d89d7a5..797d5ce74b3 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -392,6 +392,12 @@ LPIPS .. autoclass:: torchmetrics.image.lpip_similarity.LPIPS :noindex: +MultiScaleStructuralSimilarityIndexMeasure +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.MultiScaleStructuralSimilarityIndexMeasure + :noindex: + PSNR ~~~~ diff --git a/requirements/test.txt b/requirements/test.txt index c78e1e9a453..29ebfb02903 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -35,3 +35,6 @@ rouge-score>=0.0.4 bert_score==0.3.10 transformers>=4.0 sacrebleu>=2.0.0 + +# image +pytorch_msssim diff --git a/tests/image/test_ms_ssim.py b/tests/image/test_ms_ssim.py new file mode 100644 index 00000000000..f5ec4435fb4 --- /dev/null +++ b/tests/image/test_ms_ssim.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from pytorch_msssim import ms_ssim + +from tests.helpers import seed_all +from tests.helpers.testers import NUM_BATCHES, MetricTester +from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure +from torchmetrics.image.ms_ssim import MultiScaleStructuralSimilarityIndexMeasure + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +BATCH_SIZE = 1 + +_inputs = [] +for size, coef in [(128, 0.9), (128, 0.7)]: + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 1, size, size) + _inputs.append( + Input( + preds=preds, + target=preds * coef, + ) + ) + + +def pytorch_ms_ssim(preds, target, data_range, kernel_size): + return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size) + + +@pytest.mark.parametrize( + "preds, target", + [(i.preds, i.target) for i in _inputs], +) +@pytest.mark.parametrize("kernel_size", [5, 7]) +class TestMultiScaleStructuralSimilarityIndexMeasure(MetricTester): + atol = 6e-3 + + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_ms_ssim(self, preds, target, kernel_size, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + MultiScaleStructuralSimilarityIndexMeasure, + partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, + dist_sync_on_step=dist_sync_on_step, + ) + + def test_ms_ssim_functional(self, preds, target, kernel_size): + self.run_functional_metric_test( + preds, + target, + multiscale_structural_similarity_index_measure, + partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, + ) + + def test_ms_ssim_differentiability(self, preds, target, kernel_size): + # We need to minimize this example to make the test 
tractable + single_beta = (1.0,) + _preds = preds[:, :, :, :16, :16] + _target = target[:, :, :, :16, :16] + + self.run_differentiability_test( + _preds.type(torch.float64), + _target.type(torch.float64), + metric_functional=multiscale_structural_similarity_index_measure, + metric_module=MultiScaleStructuralSimilarityIndexMeasure, + metric_args={ + "data_range": 1.0, + "kernel_size": (kernel_size, kernel_size), + "betas": single_beta, + }, + ) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 183ba3cdcba..c85399b39b8 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -52,7 +52,7 @@ Specificity, StatScores, ) -from torchmetrics.image import PSNR, SSIM # noqa: E402 +from torchmetrics.image import PSNR, SSIM, MultiScaleStructuralSimilarityIndexMeasure # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.metric_collections import MetricCollection # noqa: E402 from torchmetrics.regression import ( # noqa: E402 @@ -135,6 +135,7 @@ "MinMaxMetric", "MinMetric", "MultioutputWrapper", + "MultiScaleStructuralSimilarityIndexMeasure", "PearsonCorrcoef", "PearsonCorrCoef", "PIT", diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 4e8252a1688..64ce23b121e 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -37,6 +37,7 @@ from torchmetrics.functional.classification.specificity import specificity from torchmetrics.functional.classification.stat_scores import stat_scores from torchmetrics.functional.image.gradients import image_gradients +from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure from torchmetrics.functional.image.psnr import psnr from torchmetrics.functional.image.ssim import ssim from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity @@ -105,6 +106,7 @@ "mean_absolute_percentage_error", "mean_squared_error", "mean_squared_log_error", + 
"multiscale_structural_similarity_index_measure", "pairwise_cosine_similarity", "pairwise_euclidean_distance", "pairwise_linear_similarity", diff --git a/torchmetrics/functional/image/ms_ssim.py b/torchmetrics/functional/image/ms_ssim.py new file mode 100644 index 00000000000..6ce5f83696a --- /dev/null +++ b/torchmetrics/functional/image/ms_ssim.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.image.ssim import _ssim_compute, _ssim_update + + +def _get_normalized_sim_and_cs( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + normalize: Optional[Literal["relu", "simple"]] = None, +) -> Tuple[Tensor, Tensor]: + sim, contrast_sensitivity = _ssim_compute( + preds, target, kernel_size, sigma, reduction, data_range, k1, k2, return_contrast_sensitivity=True + ) + if normalize == "relu": + sim = torch.relu(sim) + contrast_sensitivity = torch.relu(contrast_sensitivity) + return sim, contrast_sensitivity + + +def _multiscale_ssim_compute( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] 
= (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = ( + 0.0448, + 0.2856, + 0.3001, + 0.2363, + 0.1333, + ), + normalize: Optional[Literal["relu", "simple"]] = None, +) -> Tensor: + """Computes Multi-Scale Structural Similarity Index Measure. Adapted from: + https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py. + + Args: + preds: estimated image + target: ground truth image + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of structural similarity index measure. + k2: Parameter of structural similarity index measure. + betas: Exponent parameters for individual similarities and contrast sensitivities returned by different image + resolutions. + normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalization to improve the + training stability. This `normalize` argument is out of scope of the original implementation [1], and it is + adapted from https://github.com/jorge-pessoa/pytorch-msssim instead. + + Raises: + ValueError: + If the image height or width is smaller than ``2 ** len(betas)``. + ValueError: + If the image height is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``. + ValueError: + If the image width is smaller than ``(kernel_size[1] - 1) * max(1, (len(betas) - 1)) ** 2``.
+ """ + sim_list: List[Tensor] = [] + cs_list: List[Tensor] = [] + + if preds.size()[-1] < 2 ** len(betas) or preds.size()[-2] < 2 ** len(betas): + raise ValueError( + f"For a given number of `betas` parameters {len(betas)}, the image height and width dimensions must be" + f" larger than or equal to {2 ** len(betas)}." + ) + + _betas_div = max(1, (len(betas) - 1)) ** 2 + if preds.size()[-2] // _betas_div <= kernel_size[0] - 1: + raise ValueError( + f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[0]}," + f" the image height must be larger than {(kernel_size[0] - 1) * _betas_div}." + ) + if preds.size()[-1] // _betas_div <= kernel_size[1] - 1: + raise ValueError( + f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[1]}," + f" the image width must be larger than {(kernel_size[1] - 1) * _betas_div}." + ) + + for _ in range(len(betas)): + sim, contrast_sensitivity = _get_normalized_sim_and_cs( + preds, target, kernel_size, sigma, reduction, data_range, k1, k2, normalize + ) + sim_list.append(sim) + cs_list.append(contrast_sensitivity) + preds = F.avg_pool2d(preds, (2, 2)) + target = F.avg_pool2d(target, (2, 2)) + + sim_stack = torch.stack(sim_list) + cs_stack = torch.stack(cs_list) + + if normalize == "simple": + sim_stack = (sim_stack + 1) / 2 + cs_stack = (cs_stack + 1) / 2 + + sim_stack = sim_stack ** torch.tensor(betas, device=sim_stack.device) + cs_stack = cs_stack ** torch.tensor(betas, device=cs_stack.device) + return torch.prod(cs_stack[:-1]) * sim_stack[-1] + + +def multiscale_structural_similarity_index_measure( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + betas: Tuple[float, ...] 
= (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + normalize: Optional[Literal["relu", "simple"]] = None, +) -> Tensor: + """Computes `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure, which is a generalization of + Structural Similarity Index Measure by incorporating image details at different resolution scores. + + Args: + preds: Predictions from model of shape `[N, C, H, W]` + target: Ground truth values of shape `[N, C, H, W]` + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of structural similarity index measure. + k2: Parameter of structural similarity index measure. + betas: Exponent parameters for individual similarities and contrast sensitivities returned by different image + resolutions. + normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalization to improve the + training stability. This `normalize` argument is out of scope of the original implementation [1], and it is + adapted from https://github.com/jorge-pessoa/pytorch-msssim instead. + + Return: + Tensor with Multi-Scale SSIM score + + Raises: + TypeError: + If ``preds`` and ``target`` don't have the same data type. + ValueError: + If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + ValueError: + If the length of ``kernel_size`` or ``sigma`` is not ``2``. + ValueError: + If one of the elements of ``kernel_size`` is not an ``odd positive number``. + ValueError: + If one of the elements of ``sigma`` is not a ``positive number``.
+ + Example: + >>> from torchmetrics.functional import multiscale_structural_similarity_index_measure + >>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> multiscale_structural_similarity_index_measure(preds, target) + tensor(0.9558) + + References: + [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. + Bovik `MultiScaleSSIM`_ + """ + if not isinstance(betas, tuple): + raise ValueError("Argument `betas` is expected to be of a type tuple.") + if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas): + raise ValueError("Argument `betas` is expected to be a tuple of floats.") + if normalize and normalize not in ("relu", "simple"): + raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'") + + preds, target = _ssim_update(preds, target) + return _multiscale_ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2, betas, normalize) diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py index 86b95c5f856..6443e81f92d 100644 --- a/torchmetrics/functional/image/ssim.py +++ b/torchmetrics/functional/image/ssim.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -99,7 +99,8 @@ def _ssim_compute( data_range: Optional[float] = None, k1: float = 0.01, k2: float = 0.03, -) -> Tensor: + return_contrast_sensitivity: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Computes Structual Similarity Index Measure. 
Args: @@ -170,6 +171,11 @@ def _ssim_compute( ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower) ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w] + if return_contrast_sensitivity: + contrast_sensitivity = upper / lower + contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w] + return reduce(ssim_idx, reduction), reduce(contrast_sensitivity, reduction) + return reduce(ssim_idx, reduction) diff --git a/torchmetrics/functional/regression/__init__.py b/torchmetrics/functional/regression/__init__.py index d33edc2fe1c..d0c767684f1 100644 --- a/torchmetrics/functional/regression/__init__.py +++ b/torchmetrics/functional/regression/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure # noqa: F401 from torchmetrics.functional.image.psnr import psnr # noqa: F401 from torchmetrics.functional.image.ssim import ssim # noqa: F401 from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401 diff --git a/torchmetrics/image/__init__.py b/torchmetrics/image/__init__.py index b3595139bc6..f0be37d45ea 100644 --- a/torchmetrics/image/__init__.py +++ b/torchmetrics/image/__init__.py @@ -11,5 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from torchmetrics.image.ms_ssim import MultiScaleStructuralSimilarityIndexMeasure # noqa: F401 from torchmetrics.image.psnr import PSNR # noqa: F401 from torchmetrics.image.ssim import SSIM # noqa: F401 diff --git a/torchmetrics/image/ms_ssim.py b/torchmetrics/image/ms_ssim.py new file mode 100644 index 00000000000..41ce5f38378 --- /dev/null +++ b/torchmetrics/image/ms_ssim.py @@ -0,0 +1,145 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Tuple + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.image.ms_ssim import _multiscale_ssim_compute +from torchmetrics.functional.image.ssim import _ssim_update +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat + + +class MultiScaleStructuralSimilarityIndexMeasure(Metric): + """Computes `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure, which is a generalization of + Structural Similarity Index Measure by incorporating image details at different resolution scores. + + Args: + kernel_size: size of the gaussian kernel + sigma: Standard deviation of the gaussian kernel + reduction: a method to reduce metric score over labels. 
+ + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of structural similarity index measure. + k2: Parameter of structural similarity index measure. + betas: Exponent parameters for individual similarities and contrast sensitivities returned by different image + resolutions. + normalize: When MultiScaleStructuralSimilarityIndexMeasure loss is used for training, it is desirable to use + normalization to improve the training stability. This `normalize` argument is out of scope of the original + implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead. + + Return: + Tensor with Multi-Scale SSIM score + + Example: + >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure + >>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure() + >>> ms_ssim(preds, target) + tensor(0.9558) + + References: + [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. + Bovik `MultiScaleSSIM`_ + """ + + preds: List[Tensor] + target: List[Tensor] + higher_is_better = True + is_differentiable = True + + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: Optional[float] = None, + k1: float = 0.01, + k2: float = 0.03, + betas: Tuple[float, ...]
= (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + normalize: Optional[Literal["relu", "simple"]] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + rank_zero_warn( + "Metric `MultiScaleStructuralSimilarityIndexMeasure` will save all targets and" + " predictions in buffer. For large datasets this may lead" + " to large memory footprint." + ) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + is_pair = isinstance(kernel_size, Sequence) and len(kernel_size) == 2 # guards iteration below + if not is_pair or not all(isinstance(ks, int) for ks in kernel_size): + raise ValueError( + "Argument `kernel_size` expected to be an sequence of size 2 where each element is an int" + f" but got {kernel_size}" + ) + self.kernel_size = kernel_size + self.sigma = sigma + self.data_range = data_range + self.k1 = k1 + self.k2 = k2 + self.reduction = reduction + if not isinstance(betas, tuple): + raise ValueError("Argument `betas` is expected to be of a type tuple.") + if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas): + raise ValueError("Argument `betas` is expected to be a tuple of floats.") + self.betas = betas + if normalize and normalize not in ("relu", "simple"): + raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'") + self.normalize = normalize + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets.
+ + Args: + preds: Predictions from model of shape `[N, C, H, W]` + target: Ground truth values of shape `[N, C, H, W]` + """ + preds, target = _ssim_update(preds, target) + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Computes multi-scale structural similarity over state.""" + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + return _multiscale_ssim_compute( + preds, + target, + self.kernel_size, + self.sigma, + self.reduction, + self.data_range, + self.k1, + self.k2, + self.betas, + self.normalize, + )