Skip to content

Commit

Permalink
⚡️ Add N-dimensional images support to (MS-)SSIM
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jan 21, 2021
1 parent fb2a431 commit 3890868
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
38 changes: 22 additions & 16 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
import torch.nn as nn
import torch.nn.functional as F

from piqa.utils import _jit, build_reduce, gaussian_kernel, channel_convs
from piqa.utils import (
_jit,
build_reduce,
gaussian_kernel,
kernel_views,
channel_convs,
)

from typing import List, Tuple

Expand Down Expand Up @@ -58,8 +64,8 @@ def _ssim(
In practice, SSIM and CS are averaged over the image width and height.
Args:
x: An input tensor, \((N, C, H, W)\).
y: A target tensor, \((N, C, H, W)\).
x: An input tensor, \((N, C, H, *)\).
y: A target tensor, \((N, C, H, *)\).
kernel: A smoothing kernel, \((C, 1, K)\).
E.g. `piqa.utils.gaussian_kernel`.
value_range: The value range \(L\) of the inputs (usually 1. or 255).
Expand All @@ -71,8 +77,8 @@ def _ssim(
The channel-wise SSIM and CS tensors, both \((N, C)\).
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> x = torch.rand(5, 3, 64, 64, 64)
>>> y = torch.rand(5, 3, 64, 64, 64)
>>> kernel = gaussian_kernel(7).repeat(3, 1, 1)
>>> ss, cs = _ssim(x, y, kernel)
>>> ss.size(), cs.size()
Expand All @@ -82,7 +88,7 @@ def _ssim(
c1 = (k1 * value_range) ** 2
c2 = (k2 * value_range) ** 2

window = [kernel.unsqueeze(-1), kernel.unsqueeze(-2)]
window = kernel_views(kernel, x.dim() - 2)

# Mean (mu)
mu_x = channel_convs(x, window)
Expand Down Expand Up @@ -132,8 +138,8 @@ def _ms_ssim(
the original tensors by a factor \(2^{i - 1}\).
Args:
x: An input tensor, \((N, C, H, W)\).
y: A target tensor, \((N, C, H, W)\).
x: An input tensor, \((N, C, H, *)\).
y: A target tensor, \((N, C, H, *)\).
kernel: A smoothing kernel, \((C, 1, K)\).
E.g. `piqa.utils.gaussian_kernel`.
weights: The weights \(\gamma_i\) of the scales, \((M,)\).
Expand Down Expand Up @@ -185,8 +191,8 @@ def ssim(
r"""Returns the SSIM between \(x\) and \(y\).
Args:
x: An input tensor, \((N, C, H, W)\).
y: A target tensor, \((N, C, H, W)\).
x: An input tensor, \((N, C, H, *)\).
y: A target tensor, \((N, C, H, *)\).
`**kwargs` are transmitted to `SSIM`.
Expand Down Expand Up @@ -214,8 +220,8 @@ def ms_ssim(
r"""Returns the MS-SSIM between \(x\) and \(y\).
Args:
x: An input tensor, \((N, C, H, W)\).
y: A target tensor, \((N, C, H, W)\).
x: An input tensor, \((N, C, H, *)\).
y: A target tensor, \((N, C, H, *)\).
`**kwargs` are transmitted to `MS_SSIM`.
Expand Down Expand Up @@ -249,8 +255,8 @@ class SSIM(nn.Module):
`**kwargs` are transmitted to `_ssim`.
Shapes:
* Input: \((N, C, H, W)\)
* Target: \((N, C, H, W)\)
* Input: \((N, C, H, *)\)
* Target: \((N, C, H, *)\)
* Output: \((N,)\) or \(()\) depending on `reduction`
Example:
Expand Down Expand Up @@ -314,8 +320,8 @@ class MS_SSIM(nn.Module):
`**kwargs` are transmitted to `_ms_ssim`.
Shapes:
* Input: \((N, C, H, W)\)
* Target: \((N, C, H, W)\)
* Input: \((N, C, H, *)\)
* Target: \((N, C, H, *)\)
* Output: \((N,)\) or \(()\) depending on `reduction`
Example:
Expand Down
40 changes: 40 additions & 0 deletions piqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,46 @@ def gaussian_kernel(
return kernel


def kernel_views(kernel: torch.Tensor, n: int = 2) -> List[torch.Tensor]:
    r"""Expands a 1-dimensional kernel into its \(N\) axis-aligned views.

    Each returned view holds the same values as `kernel`, laid out along a
    different spatial axis of an \(N\)-dimensional window. The views are
    returned in axis order, the first one being aligned with the first
    spatial dimension.

    Args:
        kernel: A kernel, \((C, 1, K)\).
        n: The number of dimensions \(N\).

    Returns:
        The list of views, each \((C, 1, \underbrace{1, \dots, 1}_{i}, K,
        \underbrace{1, \dots, 1}_{N - i - 1})\).

    Example:
        >>> kernel = gaussian_kernel(5, sigma=1.5).repeat(3, 1, 1)
        >>> kernel.size()
        torch.Size([3, 1, 5])
        >>> views = kernel_views(kernel, n=2)
        >>> views[0].size(), views[1].size()
        (torch.Size([3, 1, 5, 1]), torch.Size([3, 1, 1, 5]))
    """

    # Low-dimensional shortcuts: the kernel itself, or two plain unsqueezes.
    if n == 1:
        return [kernel]
    if n == 2:
        return [kernel.unsqueeze(-1), kernel.unsqueeze(-2)]

    # General case: the kernel must be exactly (C, 1, K) for this unpacking.
    channels, _, length = kernel.size()

    views = []

    for axis in range(n):
        # `axis` leading singleton dims, then K, then the trailing ones.
        dims = [channels, 1] + [1] * axis + [length] + [1] * (n - axis - 1)
        views.append(kernel.view(dims))

    return views


def haar_kernel(size: int) -> torch.Tensor:
r"""Returns the horizontal Haar kernel.
Expand Down

0 comments on commit 3890868

Please sign in to comment.