diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 92dd56d947f7..bc2da2e12cba 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -234,6 +234,7 @@ Complete list of metrics - :class:`~ignite.metrics.Recall` - :class:`~ignite.metrics.RootMeanSquaredError` - :class:`~ignite.metrics.RunningAverage` + - :class:`~ignite.metrics.SSIM` - :class:`~ignite.metrics.TopKCategoricalAccuracy` - :class:`~ignite.metrics.VariableAccumulation` @@ -278,6 +279,8 @@ Complete list of metrics .. autoclass:: RunningAverage +.. autoclass:: SSIM + .. autoclass:: TopKCategoricalAccuracy .. autoclass:: VariableAccumulation diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 5f4a3978ae99..de0415bbe6c6 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -14,6 +14,7 @@ from ignite.metrics.recall import Recall from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage +from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy __all__ = [ @@ -39,4 +40,5 @@ "RunningAverage", "VariableAccumulation", "Frequency", + "SSIM", ] diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py new file mode 100644 index 000000000000..91491432db37 --- /dev/null +++ b/ignite/metrics/ssim.py @@ -0,0 +1,170 @@ +from typing import Callable, Sequence, Union + +import torch +import torch.nn.functional as F + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["SSIM"] + + +class SSIM(Metric): + """ + Computes Structual Similarity Index Measure + + Args: + data_range (int or float): Range of the image. Typically, ``1.0`` or ``255``. + kernel_size (int or list or tuple of int): Size of the kernel. Default: (11, 11) + sigma (float or list or tuple of float): Standard deviation of the gaussian kernel. + Argument is used if ``gaussian=True``. Default: (1.5, 1.5) + k1 (float): Parameter of SSIM. Default: 0.01 + k2 (float): Parameter of SSIM. Default: 0.03 + gaussian (bool): ``True`` to use gaussian kernel, ``False`` to use uniform kernel + output_transform (callable, optional): A callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. + + Example: + + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. + + ``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need + to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape. + + .. code-block:: python + + def process_function(engine, batch): + # ... + return y_pred, y + engine = Engine(process_function) + metric = SSIM(data_range=1.0) + metric.attach(engine, "ssim") + """ + + def __init__( + self, + data_range: Union[int, float], + kernel_size: Union[int, Sequence[int]] = (11, 11), + sigma: Union[float, Sequence[float]] = (1.5, 1.5), + k1: float = 0.01, + k2: float = 0.03, + gaussian: bool = True, + output_transform: Callable = lambda x: x, + ): + if isinstance(kernel_size, int): + self.kernel_size = [kernel_size, kernel_size] + elif isinstance(kernel_size, Sequence): + self.kernel_size = kernel_size + else: + raise ValueError("Argument kernel_size should be either int or a sequence of int.") + + if isinstance(sigma, float): + self.sigma = [sigma, sigma] + elif isinstance(sigma, Sequence): + self.sigma = sigma + else: + raise ValueError("Argument sigma should be either float or a sequence of float.") + + if any(x % 2 == 0 or x <= 0 for x in self.kernel_size): + raise ValueError("Expected kernel_size to have odd positive number. Got {}.".format(kernel_size)) + + if any(y <= 0 for y in self.sigma): + raise ValueError("Expected sigma to have positive number. Got {}.".format(sigma)) + + self.gaussian = gaussian + self.c1 = (k1 * data_range) ** 2 + self.c2 = (k2 * data_range) ** 2 + self.pad_h = (self.kernel_size[0] - 1) // 2 + self.pad_w = (self.kernel_size[1] - 1) // 2 + self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) + super(SSIM, self).__init__(output_transform=output_transform) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_of_batchwise_ssim = 0.0 + self._num_examples = 0 + self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) + + def _uniform(self, kernel_size): + max, min = 2.5, -2.5 + kernel = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32) + for i, j in enumerate(kernel): + if min <= j <= max: + kernel[i] = 1 / (max - min) + else: + kernel[i] = 0 + + return kernel.unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian(self, kernel_size, sigma): + kernel = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32) + gauss = torch.exp(-kernel.pow(2) / (2 * pow(sigma, 2))) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian_or_uniform_kernel(self, kernel_size, sigma): + if self.gaussian: + kernel_x = self._gaussian(kernel_size[0], sigma[0]) + kernel_y = self._gaussian(kernel_size[1], sigma[1]) + else: + kernel_x = self._uniform(kernel_size[0]) + kernel_y = self._uniform(kernel_size[1]) + + return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size) + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + y_pred, y = output + if y_pred.dtype != y.dtype: + raise TypeError( + "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format( + y_pred.dtype, y.dtype + ) + ) + + if y_pred.shape != y.shape: + raise ValueError( + "Expected y_pred and y to have the same shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape) + ) + + if len(y_pred.shape) != 4 or len(y.shape) != 4: + raise ValueError( + "Expected y_pred and y to have BxCxHxW shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape) + ) + + channel = y_pred.size(1) + if len(self._kernel.shape) < 4: + self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device) + + y_pred = F.pad(y_pred, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect") + y = F.pad(y, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect") + + input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y]) + outputs = F.conv2d(input_list, self._kernel, groups=channel) + + output_list = [outputs[x * y_pred.size(0) : (x + 1) * y_pred.size(0)] for x in range(len(outputs))] + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target + + a1 = 2 * mu_pred_target + self.c1 + a2 = 2 * sigma_pred_target + self.c2 + b1 = mu_pred_sq + mu_target_sq + self.c1 + b2 = sigma_pred_sq + sigma_target_sq + self.c2 + + ssim_idx = (a1 * a2) / (b1 * b2) + self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64) + self._num_examples += y.shape[0] + + @sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples") + def compute(self) -> torch.Tensor: + if self._num_examples == 0: + raise NotComputableError("SSIM must have at least one example before it can be computed.") + return torch.sum(self._sum_of_batchwise_ssim / self._num_examples) diff --git a/requirements-dev.txt b/requirements-dev.txt index 47b63d5f3291..2e04e363f873 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,6 +18,7 @@ neptune-client tensorboard pynvml; python_version > '3.5' trains>=0.15.1 +scikit-image>=0.15.0 # Examples dependencies pandas gym diff --git a/tests/ignite/metrics/test_ssim.py b/tests/ignite/metrics/test_ssim.py new file mode 100644 index 000000000000..77197629865d --- /dev/null +++ b/tests/ignite/metrics/test_ssim.py @@ -0,0 +1,177 @@ +import os + +import pytest +import torch + +import ignite.distributed as idist +from ignite.exceptions import NotComputableError +from ignite.metrics import SSIM + +try: + from skimage.metrics import structural_similarity as ski_ssim +except ImportError: + from skimage.measure import compare_ssim as ski_ssim + + +def test_zero_div(): + ssim = SSIM(data_range=1.0) + with pytest.raises(NotComputableError): + ssim.compute() + + +def test_invalid_ssim(): + y_pred = torch.rand(16, 1, 32, 32) + y = y_pred + 0.125 + with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got 10."): + ssim = SSIM(data_range=1.0, kernel_size=10) + ssim.update((y_pred, y)) + ssim.compute() + + with pytest.raises(ValueError, match=r"Expected kernel_size to have odd positive number. Got -1."): + ssim = SSIM(data_range=1.0, kernel_size=-1) + ssim.update((y_pred, y)) + ssim.compute() + + with pytest.raises(ValueError, match=r"Argument kernel_size should be either int or a sequence of int."): + ssim = SSIM(data_range=1.0, kernel_size=1.0) + ssim.update((y_pred, y)) + ssim.compute() + + with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."): + ssim = SSIM(data_range=1.0, sigma=-1) + ssim.update((y_pred, y)) + ssim.compute() + + with pytest.raises(ValueError, match=r"Argument sigma should be either float or a sequence of float."): + ssim = SSIM(data_range=1.0, sigma=1) + ssim.update((y_pred, y)) + ssim.compute() + + +def test_ssim(): + ssim = SSIM(data_range=1.0) + device = "cuda" if torch.cuda.is_available() else "cpu" + y_pred = torch.rand(16, 3, 64, 64, device=device) + y = y_pred * 0.65 + ssim.update((y_pred, y)) + + np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() + np_y = np_pred * 0.65 + np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0) + + assert isinstance(ssim.compute(), torch.Tensor) + assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4) + + ssim = SSIM(data_range=1.0, gaussian=False, kernel_size=7) + device = "cuda" if torch.cuda.is_available() else "cpu" + y_pred = torch.rand(16, 3, 227, 227, device=device) + y = y_pred * 0.65 + ssim.update((y_pred, y)) + + np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() + np_y = np_pred * 0.65 + np_ssim = ski_ssim(np_pred, np_y, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0) + + assert isinstance(ssim.compute(), torch.Tensor) + assert torch.allclose(ssim.compute(), torch.tensor(np_ssim, dtype=torch.float64, device=device), atol=1e-4) + + +def _test_distrib_integration(device, tol=1e-4): + from ignite.engine import Engine + + rank = idist.get_rank() + n_iters = 100 + s = 10 + offset = n_iters * s + + y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device) + y = y_pred * 0.65 + + def update(engine, i): + return ( + y_pred[i * s + offset * rank : (i + 1) * s + offset * rank], + y[i * s + offset * rank : (i + 1) * s + offset * rank], + ) + + engine = Engine(update) + SSIM(data_range=1.0).attach(engine, "ssim") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + assert "ssim" in engine.state.metrics + res = engine.state.metrics["ssim"] + + np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() + np_true = np_pred * 0.65 + true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True, data_range=1.0) + + assert pytest.approx(res, abs=tol) == true_res + + engine = Engine(update) + SSIM(data_range=1.0, gaussian=False, kernel_size=7).attach(engine, "ssim") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + assert "ssim" in engine.state.metrics + res = engine.state.metrics["ssim"] + + np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy() + np_true = np_pred * 0.65 + true_res = ski_ssim(np_pred, np_true, win_size=7, multichannel=True, gaussian_weights=False, data_range=1.0) + + assert pytest.approx(res, abs=tol) == true_res + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): + + device = "cuda:{}".format(local_rank) + _test_distrib_integration(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +def test_distrib_cpu(distributed_context_single_node_gloo): + device = "cpu" + _test_distrib_integration(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): + device = "cpu" + _test_distrib_integration(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): + device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + _test_distrib_integration(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_single_device_xla(): + device = idist.device() + _test_distrib_integration(device, tol=1e-3) + + +def _test_distrib_xla_nprocs(index): + device = idist.device() + _test_distrib_integration(device, tol=1e-3) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_xla_nprocs(xmp_executor): + n = int(os.environ["NUM_TPU_WORKERS"]) + xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)