-
Notifications
You must be signed in to change notification settings - Fork 402
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
MultiScaleStructuralSimilarityIndexMeasure
(#679)
* [WIP] Add MS-SSIM * Fix docs and a mypy issue * fix badge * Apply some SkafteNicki's suggestions from code review * Fix docstring: MS_SSIM -> MultiScaleSSIM * Fix missed MS_SSIM in a docstring * flake8 * Add tests * Update doc example * Fix seed for doc + update some tests * Clean code * Update doc example + make diff test tractable * torch.manual seed passed directly to torch.rand * Apply suggestions from code review * Drop some tests ; fix docs issues + add missing docs * Update paper.md (#690) * ci: rename oldest * CI: set HF caching (#691) * Apply suggestions: ms_ssim -> multiscale_ssim + typing * Update doc reference * Apply SkafteNicki's suggestions * Update test reference package * Clean test * Change naming multiscale_ssim -> multiscale_structual_similarity_index_measure * Update docs references * Fix a typo: structual -> structural * fix one last typo * Fix ~ len in a doc Co-authored-by: Jirka <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Maxim Grechkin <maximsch2@gmail.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
- Loading branch information
1 parent
c20860a
commit 9288222
Showing
14 changed files
with
475 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,3 +35,6 @@ rouge-score>=0.0.4 | |
bert_score==0.3.10 | ||
transformers>=4.0 | ||
sacrebleu>=2.0.0 | ||
|
||
# image | ||
pytorch_msssim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import namedtuple | ||
from functools import partial | ||
|
||
import pytest | ||
import torch | ||
from pytorch_msssim import ms_ssim | ||
|
||
from tests.helpers import seed_all | ||
from tests.helpers.testers import NUM_BATCHES, MetricTester | ||
from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure | ||
from torchmetrics.image.ms_ssim import MultiScaleStructuralSimilarityIndexMeasure | ||
|
||
seed_all(42) | ||
|
||
Input = namedtuple("Input", ["preds", "target"]) | ||
BATCH_SIZE = 1 | ||
|
||
_inputs = [] | ||
for size, coef in [(128, 0.9), (128, 0.7)]: | ||
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 1, size, size) | ||
_inputs.append( | ||
Input( | ||
preds=preds, | ||
target=preds * coef, | ||
) | ||
) | ||
|
||
|
||
def pytorch_ms_ssim(preds, target, data_range, kernel_size): | ||
return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds, target", | ||
[(i.preds, i.target) for i in _inputs], | ||
) | ||
@pytest.mark.parametrize("kernel_size", [5, 7]) | ||
class TestMultiScaleStructuralSimilarityIndexMeasure(MetricTester): | ||
atol = 6e-3 | ||
|
||
@pytest.mark.parametrize("ddp", [False, True]) | ||
@pytest.mark.parametrize("dist_sync_on_step", [False, True]) | ||
def test_ms_ssim(self, preds, target, kernel_size, ddp, dist_sync_on_step): | ||
self.run_class_metric_test( | ||
ddp, | ||
preds, | ||
target, | ||
MultiScaleStructuralSimilarityIndexMeasure, | ||
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size), | ||
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, | ||
dist_sync_on_step=dist_sync_on_step, | ||
) | ||
|
||
def test_ms_ssim_functional(self, preds, target, kernel_size): | ||
self.run_functional_metric_test( | ||
preds, | ||
target, | ||
multiscale_structural_similarity_index_measure, | ||
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size), | ||
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, | ||
) | ||
|
||
def test_ms_ssim_differentiability(self, preds, target, kernel_size): | ||
# We need to minimize this example to make the test tractable | ||
single_beta = (1.0,) | ||
_preds = preds[:, :, :, :16, :16] | ||
_target = target[:, :, :, :16, :16] | ||
|
||
self.run_differentiability_test( | ||
_preds.type(torch.float64), | ||
_target.type(torch.float64), | ||
metric_functional=multiscale_structural_similarity_index_measure, | ||
metric_module=MultiScaleStructuralSimilarityIndexMeasure, | ||
metric_args={ | ||
"data_range": 1.0, | ||
"kernel_size": (kernel_size, kernel_size), | ||
"betas": single_beta, | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List, Optional, Sequence, Tuple, Union | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
from torchmetrics.functional.image.ssim import _ssim_compute, _ssim_update | ||
|
||
|
||
def _get_normalized_sim_and_cs( | ||
preds: Tensor, | ||
target: Tensor, | ||
kernel_size: Sequence[int] = (11, 11), | ||
sigma: Sequence[float] = (1.5, 1.5), | ||
reduction: str = "elementwise_mean", | ||
data_range: Optional[float] = None, | ||
k1: float = 0.01, | ||
k2: float = 0.03, | ||
normalize: Optional[Literal["relu", "simple"]] = None, | ||
) -> Tuple[Tensor, Tensor]: | ||
sim, contrast_sensitivity = _ssim_compute( | ||
preds, target, kernel_size, sigma, reduction, data_range, k1, k2, return_contrast_sensitivity=True | ||
) | ||
if normalize == "relu": | ||
sim = torch.relu(sim) | ||
contrast_sensitivity = torch.relu(contrast_sensitivity) | ||
return sim, contrast_sensitivity | ||
|
||
|
||
def _multiscale_ssim_compute( | ||
preds: Tensor, | ||
target: Tensor, | ||
kernel_size: Sequence[int] = (11, 11), | ||
sigma: Sequence[float] = (1.5, 1.5), | ||
reduction: str = "elementwise_mean", | ||
data_range: Optional[float] = None, | ||
k1: float = 0.01, | ||
k2: float = 0.03, | ||
betas: Union[Tuple[float, float, float, float, float], Tuple[float, ...]] = ( | ||
0.0448, | ||
0.2856, | ||
0.3001, | ||
0.2363, | ||
0.1333, | ||
), | ||
normalize: Optional[Literal["relu", "simple"]] = None, | ||
) -> Tensor: | ||
"""Computes Multi-Scale Structual Similarity Index Measure. Adapted from: https://github.com/jorge- | ||
pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py. | ||
Args: | ||
preds: estimated image | ||
target: ground truth image | ||
kernel_size: size of the gaussian kernel (default: (11, 11)) | ||
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) | ||
reduction: a method to reduce metric score over labels. | ||
- ``'elementwise_mean'``: takes the mean (default) | ||
- ``'sum'``: takes the sum | ||
- ``'none'``: no reduction will be applied | ||
data_range: Range of the image. If ``None``, it is determined from the image (max - min) | ||
k1: Parameter of structural similarity index measure. | ||
k2: Parameter of structural similarity index measure. | ||
betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image | ||
resolutions. | ||
normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the | ||
training stability. This `normalize` argument is out of scope of the original implementation [1], and it is | ||
adapted from https://github.com/jorge-pessoa/pytorch-msssim instead. | ||
Raises: | ||
ValueError: | ||
If the image height or width is smaller then ``2 ** len(betas)``. | ||
ValueError: | ||
If the image height is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``. | ||
ValueError: | ||
If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``. | ||
""" | ||
sim_list: List[Tensor] = [] | ||
cs_list: List[Tensor] = [] | ||
|
||
if preds.size()[-1] < 2 ** len(betas) or preds.size()[-2] < 2 ** len(betas): | ||
raise ValueError( | ||
f"For a given number of `betas` parameters {len(betas)}, the image height and width dimensions must be" | ||
f" larger than or equal to {2 ** len(betas)}." | ||
) | ||
|
||
_betas_div = max(1, (len(betas) - 1)) ** 2 | ||
if preds.size()[-2] // _betas_div <= kernel_size[0] - 1: | ||
raise ValueError( | ||
f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[0]}," | ||
f" the image height must be larger than {(kernel_size[0] - 1) * _betas_div}." | ||
) | ||
if preds.size()[-1] // _betas_div <= kernel_size[1] - 1: | ||
raise ValueError( | ||
f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[1]}," | ||
f" the image width must be larger than {(kernel_size[1] - 1) * _betas_div}." | ||
) | ||
|
||
for _ in range(len(betas)): | ||
sim, contrast_sensitivity = _get_normalized_sim_and_cs( | ||
preds, target, kernel_size, sigma, reduction, data_range, k1, k2, normalize | ||
) | ||
sim_list.append(sim) | ||
cs_list.append(contrast_sensitivity) | ||
preds = F.avg_pool2d(preds, (2, 2)) | ||
target = F.avg_pool2d(target, (2, 2)) | ||
|
||
sim_stack = torch.stack(sim_list) | ||
cs_stack = torch.stack(cs_list) | ||
|
||
if normalize == "simple": | ||
sim_stack = (sim_stack + 1) / 2 | ||
cs_stack = (cs_stack + 1) / 2 | ||
|
||
sim_stack = sim_stack ** torch.tensor(betas, device=sim_stack.device) | ||
cs_stack = cs_stack ** torch.tensor(betas, device=cs_stack.device) | ||
return torch.prod(cs_stack[:-1]) * sim_stack[-1] | ||
|
||
|
||
def multiscale_structural_similarity_index_measure( | ||
preds: Tensor, | ||
target: Tensor, | ||
kernel_size: Sequence[int] = (11, 11), | ||
sigma: Sequence[float] = (1.5, 1.5), | ||
reduction: str = "elementwise_mean", | ||
data_range: Optional[float] = None, | ||
k1: float = 0.01, | ||
k2: float = 0.03, | ||
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), | ||
normalize: Optional[Literal["relu", "simple"]] = None, | ||
) -> Tensor: | ||
"""Computes `MultiScaleSSIM`_, Multi-scale Structual Similarity Index Measure, which is a generalization of | ||
Structual Similarity Index Measure by incorporating image details at different resolution scores. | ||
Args: | ||
preds: Predictions from model of shape `[N, C, H, W]` | ||
target: Ground truth values of shape `[N, C, H, W]` | ||
kernel_size: size of the gaussian kernel (default: (11, 11)) | ||
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) | ||
reduction: a method to reduce metric score over labels. | ||
- ``'elementwise_mean'``: takes the mean (default) | ||
- ``'sum'``: takes the sum | ||
- ``'none'``: no reduction will be applied | ||
data_range: Range of the image. If ``None``, it is determined from the image (max - min) | ||
k1: Parameter of structural similarity index measure. | ||
k2: Parameter of structural similarity index measure. | ||
betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image | ||
resolutions. | ||
normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the | ||
training stability. This `normalize` argument is out of scope of the original implementation [1], and it is | ||
adapted from https://github.com/jorge-pessoa/pytorch-msssim instead. | ||
Return: | ||
Tensor with Multi-Scale SSIM score | ||
Raises: | ||
TypeError: | ||
If ``preds`` and ``target`` don't have the same data type. | ||
ValueError: | ||
If ``preds`` and ``target`` don't have ``BxCxHxW shape``. | ||
ValueError: | ||
If the length of ``kernel_size`` or ``sigma`` is not ``2``. | ||
ValueError: | ||
If one of the elements of ``kernel_size`` is not an ``odd positive number``. | ||
ValueError: | ||
If one of the elements of ``sigma`` is not a ``positive number``. | ||
Example: | ||
>>> from torchmetrics.functional import multiscale_structural_similarity_index_measure | ||
>>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42)) | ||
>>> target = preds * 0.75 | ||
>>> multiscale_structural_similarity_index_measure(preds, target) | ||
tensor(0.9558) | ||
References: | ||
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. | ||
Bovik `MultiScaleSSIM`_ | ||
""" | ||
if not isinstance(betas, tuple): | ||
raise ValueError("Argument `betas` is expected to be of a type tuple.") | ||
if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas): | ||
raise ValueError("Argument `betas` is expected to be a tuple of floats.") | ||
if normalize and normalize not in ("relu", "simple"): | ||
raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'") | ||
|
||
preds, target = _ssim_update(preds, target) | ||
return _multiscale_ssim_compute(preds, target, kernel_size, sigma, reduction, data_range, k1, k2, betas, normalize) |
Oops, something went wrong.