Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Complete list of metrics
- :class:`~ignite.metrics.Recall`
- :class:`~ignite.metrics.RootMeanSquaredError`
- :class:`~ignite.metrics.RunningAverage`
- :class:`~ignite.metrics.SSIM`
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
- :class:`~ignite.metrics.VariableAccumulation`

Expand Down Expand Up @@ -278,6 +279,8 @@ Complete list of metrics

.. autoclass:: RunningAverage

.. autoclass:: SSIM

.. autoclass:: TopKCategoricalAccuracy

.. autoclass:: VariableAccumulation
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ignite.metrics.recall import Recall
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy

__all__ = [
Expand All @@ -39,4 +40,5 @@
"RunningAverage",
"VariableAccumulation",
"Frequency",
"SSIM",
]
149 changes: 149 additions & 0 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from typing import Callable, Sequence, Union

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["SSIM"]


class SSIM(Metric):
"""
Computes Structual Similarity Index Measure

Args:
kernel_size (int or list or tuple of int): Size of the gaussian kernel. Default: (11, 11)
sigma (float or list or tuple of float): Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
data_range (int or float): Range of the image. Typically, ``1.0`` or ``255``.
k1 (float): Parameter of SSIM. Default: 0.01
k2 (float): Parameter of SSIM. Default: 0.03
output_transform (callable, optional): A callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric.

Example:

To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.

.. code-block:: python

def process_function(engine, batch):
# ...
return y_pred, y
engine = Engine(process_function)
metric = SSIM(data_range=1.0)
metric.attach(engine, "ssim")
"""

def __init__(
self,
data_range: Union[int, float],
kernel_size: Union[int, Sequence[int]] = (11, 11),
sigma: Union[float, Sequence[float]] = (1.5, 1.5),
k1: float = 0.01,
k2: float = 0.03,
output_transform: Callable = lambda x: x,
):
if isinstance(kernel_size, int):
self.kernel_size = [kernel_size, kernel_size]
elif isinstance(kernel_size, Sequence):
self.kernel_size = kernel_size
else:
raise ValueError("Argument kernel_size should be either int or a sequence of int.")

if isinstance(sigma, float):
self.sigma = [sigma, sigma]
elif isinstance(sigma, Sequence):
self.sigma = sigma
else:
raise ValueError("Argument sigma should be either float or a sequence of float.")

if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
raise ValueError("Expected kernel_size to have odd positive number. Got {}.".format(kernel_size))

if any(y <= 0 for y in self.sigma):
raise ValueError("Expected sigma to have positive number. Got {}.".format(sigma))

self.data_range = data_range
self.k1 = k1
self.k2 = k2
self._kernel = self._gaussian_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
super(SSIM, self).__init__(output_transform=output_transform)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_batchwise_ssim = 0.0
self._num_examples = 0
self._kernel = self._gaussian_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

def _gaussian(self, kernel_size, sigma):
gauss = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)

def _gaussian_kernel(self, kernel_size, sigma):
gaussian_kernel_x = self._gaussian(kernel_size[0], sigma[0])
gaussian_kernel_y = self._gaussian(kernel_size[1], sigma[1])

return torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
if y_pred.dtype != y.dtype:
raise TypeError(
"Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format(
y_pred.dtype, y.dtype
)
)

if y_pred.shape != y.shape:
raise ValueError(
"Expected y_pred and y to have the same shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
)

if len(y_pred.shape) != 4 or len(y.shape) != 4:
raise ValueError(
"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
)

c1 = (self.k1 * self.data_range) ** 2
c2 = (self.k2 * self.data_range) ** 2

channel = y_pred.size(1)
if len(self._kernel.shape) < 4:
self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)

input_list = [y_pred, y, y_pred * y_pred, y * y, y_pred * y]
output_list = [F.conv2d(x, self._kernel, groups=channel) for x in input_list]

mu_pred_sq = output_list[0].pow(2)
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]

sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
sigma_pred_target = output_list[4] - mu_pred_target

a1 = 2 * mu_pred_target + c1
a2 = 2 * sigma_pred_target + c2
b1 = mu_pred_sq + mu_target_sq + c1
b2 = sigma_pred_sq + sigma_target_sq + c2

ssim_idx = (a1 * a2) / (b1 * b2)

self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3))
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("SSIM must have at least one example before it can be computed.")
return torch.sum(self._sum_of_batchwise_ssim / self._num_examples)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ neptune-client
tensorboard
pynvml; python_version > '3.5'
trains>=0.15.1
scikit-image>=0.15.0
# Examples dependencies
pandas
gym
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ignore = E402, E721
max_line_length = 120

[isort]
known_third_party=dill,matplotlib,numpy,pytest,setuptools,sklearn,torch,torchvision,trains
known_third_party=dill,matplotlib,numpy,pytest,setuptools,skimage,sklearn,torch,torchvision,trains
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down
145 changes: 145 additions & 0 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import os

import pytest
import torch
from skimage.measure import compare_ssim as ski_ssim

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import SSIM


def test_zero_div():
ssim = SSIM(data_range=1.0)
with pytest.raises(NotComputableError):
ssim.compute()


def test_invalid_ssim():
y_pred = torch.rand(16, 1, 32, 32)
y = y_pred + 0.125
with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."):
ssim = SSIM(data_range=1.0, kernel_size=10)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."):
ssim = SSIM(data_range=1.0, kernel_size=-1)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Argument kernel_size should be either int or a sequence of int."):
ssim = SSIM(data_range=1.0, kernel_size=1.0)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."):
ssim = SSIM(data_range=1.0, sigma=-1)
ssim.update((y_pred, y))
ssim.compute()

with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."):
ssim = SSIM(data_range=1.0, sigma=1)
ssim.update((y_pred, y))
ssim.compute()


def test_ssim():
ssim = SSIM(data_range=1.0)
device = "cuda" if torch.cuda.is_available() else "cpu"
y_pred = torch.rand(16, 3, 32, 32, device=device)
y = y_pred * 0.65
ssim.update((y_pred, y))

np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
np_y = np_pred * 0.65
np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)

assert isinstance(ssim.compute(), torch.Tensor)
assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float32, device=device))


def _test_distrib_integration(device, tol=1e-6):
from ignite.engine import Engine

rank = idist.get_rank()
n_iters = 100
s = 10
offset = n_iters * s

y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device)
y = y_pred * 0.65

def update(engine, i):
return (
y_pred[i * s + offset * rank : (i + 1) * s + offset * rank],
y[i * s + offset * rank : (i + 1) * s + offset * rank],
)

engine = Engine(update)
SSIM(data_range=1.0).attach(engine, "ssim")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

assert "ssim" in engine.state.metrics
res = engine.state.metrics["ssim"]

np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
np_true = np_pred * 0.65
true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0)

assert pytest.approx(res, rel=tol) == true_res


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):

device = "cuda:{}".format(local_rank)
_test_distrib_integration(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):
device = "cpu"
_test_distrib_integration(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = "cpu"
_test_distrib_integration(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])
_test_distrib_integration(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration(device, tol=1e-3)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration(device, tol=1e-3)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)