Skip to content

Commit

Permalink
Change semantic ordering of kernel_size parameter in SSIM (#474)
Browse files Browse the repository at this point in the history
* implem

* changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
SkafteNicki and pre-commit-ci[bot] authored Aug 25, 2021
1 parent 548a597 commit 94a158c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed metric hashing ([#478](https://github.com/PyTorchLightning/metrics/pull/478))


- Fixed the semantic ordering of kernel height and width in `SSIM` metric ([#474](https://github.com/PyTorchLightning/metrics/pull/474))


## [0.5.0] - 2021-08-09

### Added
Expand Down
60 changes: 49 additions & 11 deletions tests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)


def _sk_ssim(preds, target, data_range, multichannel):
def _sk_ssim(preds, target, data_range, multichannel, kernel_size):
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
Expand All @@ -58,7 +58,7 @@ def _sk_ssim(preds, target, data_range, multichannel):
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=11,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)
Expand All @@ -68,38 +68,39 @@ def _sk_ssim(preds, target, data_range, multichannel):
"preds, target, multichannel",
[(i.preds, i.target, i.multichannel) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 11])
class TestSSIM(MetricTester):
atol = 6e-5
atol = 6e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step):
def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
SSIM,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel),
metric_args={"data_range": 1.0},
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
dist_sync_on_step=dist_sync_on_step,
)

def test_ssim_functional(self, preds, target, multichannel):
def test_ssim_functional(self, preds, target, multichannel, kernel_size):
self.run_functional_metric_test(
preds,
target,
ssim,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel),
metric_args={"data_range": 1.0},
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
)

# SSIM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SSIM metric does not support cpu + half precision")
def test_ssim_half_cpu(self, preds, target, multichannel):
def test_ssim_half_cpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_cpu(preds, target, SSIM, ssim, {"data_range": 1.0})

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_ssim_half_gpu(self, preds, target, multichannel):
def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size):
self.run_precision_test_gpu(preds, target, SSIM, ssim, {"data_range": 1.0})


Expand Down Expand Up @@ -127,3 +128,40 @@ def test_ssim_invalid_inputs(pred, target, kernel, sigma):
target = torch.rand(target)
with pytest.raises(ValueError):
ssim(pred, target, kernel, sigma)


def test_ssim_unequal_kernel_size():
    """Test the case where kernel_size[0] != kernel_size[1]"""
    # Two fixed 7x7 binary images, reshaped to (N=1, C=1, H=7, W=7) as SSIM
    # expects batched, channeled input.
    preds = torch.tensor(
        [
            [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
        ]
    ).reshape(1, 1, 7, 7)
    target = torch.tensor(
        [
            [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
            [1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0],
            [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0],
            [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        ]
    ).reshape(1, 1, 7, 7)
    # kernel order matters: swapping height and width of a non-square kernel
    # must produce a different score than the (3, 5) reference value.
    reference = torch.tensor(0.10814697)
    assert ssim(preds, target, kernel_size=(3, 5)) == reference
    assert ssim(preds, target, kernel_size=(5, 3)) != reference
8 changes: 4 additions & 4 deletions torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ def _ssim_compute(
channel = preds.size(1)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
pad_w = (kernel_size[0] - 1) // 2
pad_h = (kernel_size[1] - 1) // 2
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
Expand Down

0 comments on commit 94a158c

Please sign in to comment.