diff --git a/CHANGELOG.md b/CHANGELOG.md index 536a8253bab..9cd9d8b491b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed metric hashing ([#478](https://github.com/PyTorchLightning/metrics/pull/478)) +- Fixed the semantic ordering of kernel height and width in `SSIM` metric ([#474](https://github.com/PyTorchLightning/metrics/pull/474)) + + ## [0.5.0] - 2021-08-09 ### Added diff --git a/tests/image/test_ssim.py b/tests/image/test_ssim.py index a07de994e51..6a8360f0caa 100644 --- a/tests/image/test_ssim.py +++ b/tests/image/test_ssim.py @@ -44,7 +44,7 @@ ) -def _sk_ssim(preds, target, data_range, multichannel): +def _sk_ssim(preds, target, data_range, multichannel, kernel_size): c, h, w = preds.shape[-3:] sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy() @@ -58,7 +58,7 @@ def _sk_ssim(preds, target, data_range, multichannel): data_range=data_range, multichannel=multichannel, gaussian_weights=True, - win_size=11, + win_size=kernel_size, sigma=1.5, use_sample_covariance=False, ) @@ -68,38 +68,39 @@ def _sk_ssim(preds, target, data_range, multichannel): "preds, target, multichannel", [(i.preds, i.target, i.multichannel) for i in _inputs], ) +@pytest.mark.parametrize("kernel_size", [5, 11]) class TestSSIM(MetricTester): - atol = 6e-5 + atol = 6e-3 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): + def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, target, SSIM, - partial(_sk_ssim, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, + partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, dist_sync_on_step=dist_sync_on_step, ) - def test_ssim_functional(self, preds, target, multichannel): + def test_ssim_functional(self, preds, target, multichannel, kernel_size): self.run_functional_metric_test( preds, target, ssim, - partial(_sk_ssim, data_range=1.0, multichannel=multichannel), - metric_args={"data_range": 1.0}, + partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size), + metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)}, ) # SSIM half + cpu does not work due to missing support in torch.log @pytest.mark.xfail(reason="SSIM metric does not support cpu + half precision") - def test_ssim_half_cpu(self, preds, target, multichannel): + def test_ssim_half_cpu(self, preds, target, multichannel, kernel_size): self.run_precision_test_cpu(preds, target, SSIM, ssim, {"data_range": 1.0}) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_ssim_half_gpu(self, preds, target, multichannel): + def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size): self.run_precision_test_gpu(preds, target, SSIM, ssim, {"data_range": 1.0}) @@ -127,3 +128,40 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma): target = torch.rand(target) with pytest.raises(ValueError): ssim(pred, target, kernel, sigma) + + +def test_ssim_unequal_kernel_size(): + """Test the case where kernel_size[0] != kernel_size[1]""" + preds = torch.tensor( + [ + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + ] + ] + ] + ) + target = torch.tensor( + [ + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0], + ] + ] + ] + ) + # kernel order matters + assert ssim(preds, target, kernel_size=(3, 5)) == torch.tensor(0.10814697) + assert ssim(preds, target, kernel_size=(5, 3)) != torch.tensor(0.10814697) diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py index 43289a3268f..2e975d2add1 100644 --- a/torchmetrics/functional/image/ssim.py +++ b/torchmetrics/functional/image/ssim.py @@ -146,11 +146,11 @@ def _ssim_compute( channel = preds.size(1) dtype = preds.dtype kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device) - pad_w = (kernel_size[0] - 1) // 2 - pad_h = (kernel_size[1] - 1) // 2 + pad_h = (kernel_size[0] - 1) // 2 + pad_w = (kernel_size[1] - 1) // 2 - preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect") - target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect") + preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect") + target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect") input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) outputs = F.conv2d(input_list, kernel, groups=channel)