Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Online SSIM and MS-SSIM Computation #1231

Merged
merged 21 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 19 additions & 25 deletions src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def structural_similarity_index_measure(

Example:
>>> from torchmetrics.functional import structural_similarity_index_measure
>>> preds = torch.rand([16, 1, 16, 16])
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> structural_similarity_index_measure(preds, target)
tensor(0.9219)
Expand Down Expand Up @@ -354,8 +354,7 @@ def _multiscale_ssim_compute(
ValueError:
If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
"""
sim_list: List[Tensor] = []
cs_list: List[Tensor] = []
mcs_list: List[Tensor] = []

is_3d = len(preds.shape) == 5

Expand Down Expand Up @@ -384,10 +383,10 @@ def _multiscale_ssim_compute(

for _ in range(len(betas)):
sim, contrast_sensitivity = _get_normalized_sim_and_cs(
preds, target, gaussian_kernel, sigma, kernel_size, reduction, data_range, k1, k2, normalize=normalize
preds, target, gaussian_kernel, sigma, kernel_size, None, data_range, k1, k2, normalize=normalize
)
sim_list.append(sim)
cs_list.append(contrast_sensitivity)
mcs_list.append(contrast_sensitivity)

if len(kernel_size) == 2:
preds = F.avg_pool2d(preds, (2, 2))
target = F.avg_pool2d(target, (2, 2))
Expand All @@ -396,23 +395,18 @@ def _multiscale_ssim_compute(
target = F.avg_pool3d(target, (2, 2, 2))
else:
raise ValueError("length of kernel_size is neither 2 nor 3")
sim_stack = torch.stack(sim_list)
cs_stack = torch.stack(cs_list)

mcs_list[-1] = sim
mcs_stack = torch.stack(mcs_list)

if normalize == "simple":
sim_stack = (sim_stack + 1) / 2
cs_stack = (cs_stack + 1) / 2

if reduction is None or reduction == "none":
betas = torch.tensor(betas).unsqueeze(1).repeat(1, sim_stack.shape[0])
sim_stack = sim_stack ** torch.tensor(betas, device=sim_stack.device)
cs_stack = cs_stack ** torch.tensor(betas, device=cs_stack.device)
cs_and_sim = torch.cat((cs_stack[:-1], sim_stack[-1:]), axis=0)
return torch.prod(cs_and_sim, axis=0)
else:
sim_stack = sim_stack ** torch.tensor(betas, device=sim_stack.device)
cs_stack = cs_stack ** torch.tensor(betas, device=cs_stack.device)
return torch.prod(cs_stack[:-1]) * sim_stack[-1]
mcs_stack = (mcs_stack + 1) / 2

betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
mcs_weighted = mcs_stack**betas
mcs_per_image = torch.prod(mcs_weighted, axis=0)

return reduce(mcs_per_image, reduction)


def multiscale_structural_similarity_index_measure(
Expand All @@ -426,7 +420,7 @@ def multiscale_structural_similarity_index_measure(
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Optional[Literal["relu", "simple"]] = None,
normalize: Optional[Literal["relu", "simple"]] = "relu",
) -> Tensor:
"""Computes `MultiScaleSSIM`_, Multi-scale Structual Similarity Index Measure, which is a generalization of
Structual Similarity Index Measure by incorporating image details at different resolution scores.
Expand Down Expand Up @@ -468,10 +462,10 @@ def multiscale_structural_similarity_index_measure(

Example:
>>> from torchmetrics.functional import multiscale_structural_similarity_index_measure
>>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42))
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target)
tensor(0.9558)
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
justusschock marked this conversation as resolved.
Show resolved Hide resolved
tensor(0.9627)

References:
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
Expand Down
120 changes: 81 additions & 39 deletions src/torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
from typing import Any, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.image.ssim import _multiscale_ssim_compute, _ssim_compute, _ssim_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat


Expand Down Expand Up @@ -55,9 +55,9 @@ class StructuralSimilarityIndexMeasure(Metric):
Example:
>>> from torchmetrics import StructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([16, 1, 16, 16])
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ssim = StructuralSimilarityIndexMeasure()
>>> ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
justusschock marked this conversation as resolved.
Show resolved Hide resolved
>>> ssim(preds, target)
tensor(0.9219)
"""
Expand All @@ -83,14 +83,21 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `SSIM` will save all targets and"
" predictions in buffer. For large datasets this may lead"
" to large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
valid_reduction = ("elementwise_mean", "sum", "none", None)
if reduction not in valid_reduction:
raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")

if reduction in ("elementwise_mean", "sum"):
self.add_state("similarity", default=torch.tensor(0.0), dist_reduce_fx="sum")
else:
self.add_state("similarity", default=[], dist_reduce_fx="cat")

self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

if return_contrast_sensitivity or return_full_image:
self.add_state("image_return", default=[], dist_reduce_fx="cat")

self.gaussian_kernel = gaussian_kernel
self.sigma = sigma
self.kernel_size = kernel_size
Expand All @@ -109,27 +116,49 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
target: Ground truth values
"""
preds, target = _ssim_update(preds, target)
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Computes explained variance over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _ssim_compute(
similarity_pack = _ssim_compute(
preds,
target,
self.gaussian_kernel,
self.sigma,
self.kernel_size,
self.reduction,
None,
self.data_range,
self.k1,
self.k2,
self.return_full_image,
self.return_contrast_sensitivity,
)

if isinstance(similarity_pack, tuple):
similarity, image = similarity_pack
else:
similarity = similarity_pack

if self.return_contrast_sensitivity or self.return_full_image:
self.image_return.append(image)

if self.reduction in ("elementwise_mean", "sum"):
self.similarity += similarity.sum()
self.total += preds.shape[0]
else:
self.similarity.append(similarity)

def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Computes SSIM over state."""
if self.reduction == "elementwise_mean":
similarity = self.similarity / self.total
elif self.reduction == "sum":
similarity = self.similarity
else:
similarity = dim_zero_cat(self.similarity)

if self.return_contrast_sensitivity or self.return_full_image:
image_return = dim_zero_cat(self.image_return)
return similarity, image_return

return similarity


class MultiScaleStructuralSimilarityIndexMeasure(Metric):
"""Computes `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure, which is a generalization of
Expand Down Expand Up @@ -169,11 +198,11 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric):
Example:
>>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([1, 1, 256, 256], generator=torch.manual_seed(42))
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure()
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9558)
tensor(0.9627)

References:
[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
Expand All @@ -197,18 +226,21 @@ def __init__(
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Literal["relu", "simple", None] = None,
normalize: Literal["relu", "simple", None] = "relu",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `MS_SSIM` will save all targets and"
" predictions in buffer. For large datasets this may lead"
" to large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
valid_reduction = ("elementwise_mean", "sum", "none", None)
if reduction not in valid_reduction:
raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")

if reduction in ("elementwise_mean", "sum"):
self.add_state("similarity", default=torch.tensor(0.0), dist_reduce_fx="sum")
else:
self.add_state("similarity", default=[], dist_reduce_fx="cat")

self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

if not (isinstance(kernel_size, (Sequence, int))):
raise ValueError(
Expand Down Expand Up @@ -246,23 +278,33 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
target: Ground truth values of shape ``[N, C, H, W]``
"""
preds, target = _ssim_update(preds, target)
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Computes explained variance over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _multiscale_ssim_compute(
similarity = _multiscale_ssim_compute(
preds,
target,
self.gaussian_kernel,
self.sigma,
self.kernel_size,
self.reduction,
None,
self.data_range,
self.k1,
self.k2,
self.betas,
self.normalize,
)

if self.reduction in ("none", None):
self.similarity.append(similarity)
else:
self.similarity += similarity.sum()

self.total += preds.shape[0]

def compute(self) -> Tensor:
"""Computes MS-SSIM over state."""

if self.reduction in ("none", None):
return dim_zero_cat(self.similarity)
elif self.reduction == "sum":
return self.similarity
else:
return self.similarity / self.total