Skip to content

Commit 4a7428c

Browse files
Jeff Yang, sdesrozis, vfdev-5
authored
metrics: add SSIM (#1217)
* metrics: add SSIM * add scikit-image dependency * add distributed tests, fix docstring * .gitignore back to normal * Update ignite/metrics/ssim.py Co-authored-by: vfdev <vfdev.5@gmail.com> * .format(), separate functions * scalar input for kernel, sigma, fix py3.5 CI * apply suggestions * some fixes * fixed tpu tests * Minor code cosmetics and raised err tolerance in tests * used list comprehension convolution, fixed tests * added uniform kernel, change tolerance, various image size tests * Update ignite/metrics/ssim.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/metrics/ssim.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Fix flake8 Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent ddb1eab commit 4a7428c

File tree

5 files changed

+353
-0
lines changed

5 files changed

+353
-0
lines changed

docs/source/metrics.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ Complete list of metrics
234234
- :class:`~ignite.metrics.Recall`
235235
- :class:`~ignite.metrics.RootMeanSquaredError`
236236
- :class:`~ignite.metrics.RunningAverage`
237+
- :class:`~ignite.metrics.SSIM`
237238
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
238239
- :class:`~ignite.metrics.VariableAccumulation`
239240

@@ -278,6 +279,8 @@ Complete list of metrics
278279

279280
.. autoclass:: RunningAverage
280281

282+
.. autoclass:: SSIM
283+
281284
.. autoclass:: TopKCategoricalAccuracy
282285

283286
.. autoclass:: VariableAccumulation

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ignite.metrics.recall import Recall
1515
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
1616
from ignite.metrics.running_average import RunningAverage
17+
from ignite.metrics.ssim import SSIM
1718
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
1819

1920
__all__ = [
@@ -39,4 +40,5 @@
3940
"RunningAverage",
4041
"VariableAccumulation",
4142
"Frequency",
43+
"SSIM",
4244
]

ignite/metrics/ssim.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from typing import Callable, Sequence, Union
2+
3+
import torch
4+
import torch.nn.functional as F
5+
6+
from ignite.exceptions import NotComputableError
7+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
8+
9+
__all__ = ["SSIM"]
10+
11+
12+
class SSIM(Metric):
    """
    Computes Structural Similarity Index Measure

    Args:
        data_range (int or float): Range of the image. Typically, ``1.0`` or ``255``.
        kernel_size (int or list or tuple of int): Size of the kernel. Default: (11, 11)
        sigma (float or list or tuple of float): Standard deviation of the gaussian kernel.
            Argument is used if ``gaussian=True``. Default: (1.5, 1.5)
        k1 (float): Parameter of SSIM. Default: 0.01
        k2 (float): Parameter of SSIM. Default: 0.03
        gaussian (bool): ``True`` to use gaussian kernel, ``False`` to use uniform kernel
        output_transform (callable, optional): A callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.

    Example:

    To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
    The output of the engine's ``process_function`` needs to be in the format of
    ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

    ``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
    to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.

    .. code-block:: python

        def process_function(engine, batch):
            # ...
            return y_pred, y
        engine = Engine(process_function)
        metric = SSIM(data_range=1.0)
        metric.attach(engine, "ssim")
    """

    def __init__(
        self,
        data_range: Union[int, float],
        kernel_size: Union[int, Sequence[int]] = (11, 11),
        sigma: Union[float, Sequence[float]] = (1.5, 1.5),
        k1: float = 0.01,
        k2: float = 0.03,
        gaussian: bool = True,
        output_transform: Callable = lambda x: x,
    ):
        # A scalar kernel_size / sigma is applied to both spatial dimensions.
        if isinstance(kernel_size, int):
            self.kernel_size = [kernel_size, kernel_size]
        elif isinstance(kernel_size, Sequence):
            self.kernel_size = kernel_size
        else:
            raise ValueError("Argument kernel_size should be either int or a sequence of int.")

        # Only a genuine float (not an int) is accepted as a scalar sigma.
        if isinstance(sigma, float):
            self.sigma = [sigma, sigma]
        elif isinstance(sigma, Sequence):
            self.sigma = sigma
        else:
            raise ValueError("Argument sigma should be either float or a sequence of float.")

        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError("Expected kernel_size to have odd positive number. Got {}.".format(kernel_size))

        if any(y <= 0 for y in self.sigma):
            raise ValueError("Expected sigma to have positive number. Got {}.".format(sigma))

        self.gaussian = gaussian
        # Stabilising constants of the SSIM formula.
        self.c1 = (k1 * data_range) ** 2
        self.c2 = (k2 * data_range) ** 2
        # Padding that keeps the filtered maps at the input's spatial size (kernel sizes are odd).
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
        super(SSIM, self).__init__(output_transform=output_transform)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the accumulated statistics and rebuild the 2D (un-expanded) kernel."""
        self._sum_of_batchwise_ssim = 0.0
        self._num_examples = 0
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

    def _uniform(self, kernel_size):
        """Return a ``(1, kernel_size)`` box kernel.

        Positions are centred on zero; those inside ``[-2.5, 2.5]`` get weight ``1 / 5``
        and every other position gets ``0``.
        """
        # NOTE(review): the window bounds are hard-coded, so kernel sizes above 5 get
        # zero-weight borders and kernel sizes below 5 do not sum to 1 -- confirm intent.
        upper, lower = 2.5, -2.5
        positions = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32)
        inside = (positions >= lower) & (positions <= upper)
        kernel = inside.to(torch.float32) * (1 / (upper - lower))

        return kernel.unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian(self, kernel_size, sigma):
        """Return a ``(1, kernel_size)`` gaussian kernel normalised to sum to one."""
        kernel = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32)
        gauss = torch.exp(-kernel.pow(2) / (2 * pow(sigma, 2)))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
        """Build the 2D kernel as the outer product of two 1D kernels."""
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])

        return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate the SSIM of one batch.

        Args:
            output: pair ``(y_pred, y)`` of tensors with identical dtype and ``BxCxHxW`` shape.

        Raises:
            TypeError: if ``y_pred`` and ``y`` have different dtypes.
            ValueError: if the shapes differ or are not 4-dimensional.
        """
        y_pred, y = output
        if y_pred.dtype != y.dtype:
            raise TypeError(
                "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format(
                    y_pred.dtype, y.dtype
                )
            )

        if y_pred.shape != y.shape:
            raise ValueError(
                "Expected y_pred and y to have the same shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
            )

        if len(y_pred.shape) != 4 or len(y.shape) != 4:
            raise ValueError(
                "Expected y_pred and y to have BxCxHxW shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
            )

        batch_size = y_pred.size(0)
        channel = y_pred.size(1)

        # Expand the 2D kernel to one depthwise filter per channel. Rebuild it as well when a
        # previously expanded kernel does not match the current number of channels.
        kernel = self._kernel
        if kernel.dim() < 4 or kernel.size(0) != channel:
            base_kernel = kernel if kernel.dim() < 4 else kernel[0, 0]
            self._kernel = base_kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)

        y_pred = F.pad(y_pred, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect")
        y = F.pad(y, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect")

        # Filter the five needed maps in a single grouped convolution by stacking them on the batch axis.
        input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y])
        outputs = F.conv2d(input_list, self._kernel, groups=channel)

        # Exactly five maps were stacked above, so exactly five slices of ``batch_size`` are taken back.
        output_list = [outputs[i * batch_size : (i + 1) * batch_size] for i in range(5)]

        mu_pred_sq = output_list[0].pow(2)
        mu_target_sq = output_list[1].pow(2)
        mu_pred_target = output_list[0] * output_list[1]

        # Local (co)variances: E[a * b] - E[a] * E[b]
        sigma_pred_sq = output_list[2] - mu_pred_sq
        sigma_target_sq = output_list[3] - mu_target_sq
        sigma_pred_target = output_list[4] - mu_pred_target

        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2

        ssim_idx = (a1 * a2) / (b1 * b2)
        # Reduce the per-image means to a scalar before accumulating: adding a ``(B,)`` tensor
        # would fail as soon as a batch (e.g. the last one) has a different size.
        self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum()
        self._num_examples += batch_size

    @sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
    def compute(self) -> torch.Tensor:
        """Return the mean SSIM over all examples seen since the last reset.

        Raises:
            NotComputableError: if no example has been accumulated.
        """
        if self._num_examples == 0:
            raise NotComputableError("SSIM must have at least one example before it can be computed.")
        return torch.sum(self._sum_of_batchwise_ssim / self._num_examples)

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ neptune-client
1818
tensorboard
1919
pynvml; python_version > '3.5'
2020
trains>=0.15.1
21+
scikit-image>=0.15.0
2122
# Examples dependencies
2223
pandas
2324
gym

tests/ignite/metrics/test_ssim.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import os
2+
3+
import pytest
4+
import torch
5+
6+
import ignite.distributed as idist
7+
from ignite.exceptions import NotComputableError
8+
from ignite.metrics import SSIM
9+
10+
try:
11+
from skimage.metrics import structural_similarity as ski_ssim
12+
except ImportError:
13+
from skimage.measure import compare_ssim as ski_ssim
14+
15+
16+
def test_zero_div():
    """Computing the metric before any update must raise NotComputableError."""
    metric = SSIM(data_range=1.0)

    with pytest.raises(NotComputableError):
        metric.compute()
20+
21+
22+
def test_invalid_ssim():
    """Invalid ``kernel_size`` / ``sigma`` arguments are rejected by the constructor."""
    # The constructor itself raises, so nothing placed after it inside the
    # ``pytest.raises`` block would ever run; only the construction is exercised.
    invalid_cases = [
        ({"kernel_size": 10}, r"Expected kernel_size to have odd positive number. Got 10."),
        ({"kernel_size": -1}, r"Expected kernel_size to have odd positive number. Got -1."),
        ({"kernel_size": 1.0}, r"Argument kernel_size should be either int or a sequence of int."),
        ({"sigma": -1}, r"Argument sigma should be either float or a sequence of float."),
        ({"sigma": 1}, r"Argument sigma should be either float or a sequence of float."),
    ]

    for kwargs, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            SSIM(data_range=1.0, **kwargs)
49+
50+
51+
def test_ssim():
    """Compare SSIM with scikit-image for the gaussian and the uniform kernel."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # (SSIM kwargs, image side length, scikit-image kwargs)
    scenarios = [
        ({}, 64, {"win_size": 11, "gaussian_weights": True}),
        ({"gaussian": False, "kernel_size": 7}, 227, {"win_size": 7, "gaussian_weights": False}),
    ]

    for metric_kwargs, side, ski_kwargs in scenarios:
        metric = SSIM(data_range=1.0, **metric_kwargs)
        prediction = torch.rand(16, 3, side, side, device=device)
        target = prediction * 0.65
        metric.update((prediction, target))

        # scikit-image expects channels last.
        prediction_np = prediction.permute(0, 2, 3, 1).cpu().numpy()
        target_np = prediction_np * 0.65
        expected = ski_ssim(prediction_np, target_np, multichannel=True, data_range=1.0, **ski_kwargs)

        assert isinstance(metric.compute(), torch.Tensor)
        expected_tensor = torch.tensor(expected, dtype=torch.float64, device=device)
        assert torch.allclose(metric.compute(), expected_tensor, atol=1e-4)
77+
78+
79+
def _test_distrib_integration(device, tol=1e-4):
    """Run SSIM through an Engine on this rank's data slice and compare with scikit-image.

    Args:
        device: device on which the random images are created.
        tol: absolute tolerance used for the comparison.
    """
    from ignite.engine import Engine

    rank = idist.get_rank()
    n_iters = 100
    s = 10
    offset = n_iters * s

    y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device)
    y = y_pred * 0.65

    def update(engine, i):
        # Each rank reads its own contiguous window of ``offset`` samples, ``s`` at a time.
        start = i * s + offset * rank
        stop = (i + 1) * s + offset * rank
        return y_pred[start:stop], y[start:stop]

    # (SSIM kwargs, scikit-image kwargs)
    scenarios = [
        ({}, {"win_size": 11, "gaussian_weights": True}),
        ({"gaussian": False, "kernel_size": 7}, {"win_size": 7, "gaussian_weights": False}),
    ]

    for metric_kwargs, ski_kwargs in scenarios:
        engine = Engine(update)
        SSIM(data_range=1.0, **metric_kwargs).attach(engine, "ssim")

        engine.run(data=list(range(n_iters)), max_epochs=1)

        assert "ssim" in engine.state.metrics
        res = engine.state.metrics["ssim"]

        # scikit-image expects channels last.
        np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
        np_true = np_pred * 0.65
        true_res = ski_ssim(np_pred, np_true, multichannel=True, data_range=1.0, **ski_kwargs)

        assert pytest.approx(res, abs=tol) == true_res
125+
126+
127+
@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):
    """Single-node NCCL run on the CUDA device matching the local rank."""
    _test_distrib_integration("cuda:{}".format(local_rank))
134+
135+
136+
@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):
    """Single-node gloo run on CPU."""
    _test_distrib_integration("cpu")
141+
142+
143+
@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
    """Multi-node gloo run on CPU."""
    _test_distrib_integration("cpu")
149+
150+
151+
@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
    """Multi-node NCCL run on the CUDA device given by the context's local rank."""
    local_rank = distributed_context_multi_node_nccl["local_rank"]
    _test_distrib_integration("cuda:{}".format(local_rank))
157+
158+
159+
@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
    """Single XLA device run, using the looser 1e-3 tolerance."""
    _test_distrib_integration(idist.device(), tol=1e-3)
165+
166+
167+
def _test_distrib_xla_nprocs(index):
    """Per-process XLA entry point; ``index`` is accepted but not used."""
    _test_distrib_integration(idist.device(), tol=1e-3)
170+
171+
172+
@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
    """Launch the XLA integration check on ``NUM_TPU_WORKERS`` processes."""
    num_workers = int(os.environ["NUM_TPU_WORKERS"])
    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=num_workers)

0 commit comments

Comments
 (0)