Skip to content

Commit

Permalink
✨ Mean Deviation Similarity Index (MDSI)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 5, 2020
1 parent 3d1fd4c commit e79b340
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 0 deletions.
136 changes: 136 additions & 0 deletions spiq/mdsi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
r"""Mean Deviation Similarity Index (MDSI)
This module implements the MDSI in PyTorch.
References:
[1] Mean Deviation Similarity Index:
Efficient and Reliable Full-Reference Image Quality Evaluator
(Nafchi et al., 2016)
https://arxiv.org/abs/1608.07433
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from spiq.utils import build_reduce, prewitt_kernel, gradient2d, tensor_norm

_LHM_WEIGHTS = torch.FloatTensor([
[0.2989, 0.587, 0.114],
[0.3, 0.04, -0.35],
[0.34, -0.6, 0.17],
])


def mdsi(
x: torch.Tensor,
y: torch.Tensor,
value_range: float = 1.,
combine: str = 'sum',
c1: float = 0.00215, # 140. / (255. ** 2)
c2: float = 0.00085, # 55. / (255. ** 2)
c3: float = 0.00846, # 550. / (255. ** 2)
alpha: float = 0.6,
beta: float = 0.1,
gamma: float = 0.2,
rho: float = 1.,
q: float = 0.25,
o: float = 0.25,
) -> torch.Tensor:
r"""Returns the MDSI between `x` and `y`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
value_range: The value range of the inputs (usually 1. or 255).
combine: The combination scheme of the gradient similarity (GS) and
the chromaticity similarity (CS) (`'sum'` or `'prod'`).
For the remaining arguments, refer to [1].
"""

_, _, h, w = x.size()

# Downsample
M = max(1, min(h, w) // 256)
padding = (0, M - (w - 1 % M) + 1, 0, M - (h - 1 % M) + 1)

if sum(padding) > 0:
x = F.pad(x, pad=padding)
y = F.pad(y, pad=padding)

x = F.avg_pool2d(x, kernel_size=M)
y = F.avg_pool2d(y, kernel_size=M)

# RGB to LHM
lhm_weights = _LHM_WEIGHTS.to(x.device).view(3, 3, 1, 1)
lhm_weights /= value_range

x = F.conv2d(x, lhm_weights)
y = F.conv2d(y, lhm_weights)

# Gradient magnitude
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device)

gm_x = tensor_norm(gradient2d(x[:, :1], kernel), dim=1)
gm_y = tensor_norm(gradient2d(y[:, :1], kernel), dim=1)
gm_avg = tensor_norm(gradient2d((x + y)[:, :1] / 2., kernel), dim=1)

gm_x_sq, gm_y_sq, gm_avg_sq = gm_x ** 2, gm_y ** 2, gm_avg ** 2

# Gradient similarity
gs_x_y = (2. * gm_x * gm_y + c1) / (gm_x_sq + gm_y_sq + c1)
gs_x_avg = (2. * gm_x * gm_avg + c2) / (gm_x_sq + gm_avg_sq + c2)
gs_y_avg = (2. * gm_y * gm_avg + c2) / (gm_y_sq + gm_avg_sq + c2)

gs = gs_x_y + gs_x_avg - gs_y_avg

# Chromaticity similarity
cs_num = 2. * (x[:, 1:] * y[:, 1:]).sum(1) + c3
cs_den = (x[:, 1:] ** 2 + y[:, 1:] ** 2).sum(1) + c3
cs = cs_num / cs_den

# Gradient-chromaticity similarity
gs, cs = gs.type(torch.cfloat), cs.type(torch.cfloat)

if combine == 'prod':
gcs = (gs ** gamma) * (cs ** beta)
else: # combine == 'sum'
gcs = alpha * gs + (1. - alpha) * cs

# Mean deviation similarity
gcs_q = gcs ** q
score = (gcs_q - gcs_q.mean((-1, -2), keepdim=True)).abs()
mds = (score ** rho).mean((-1, -2)) ** (o / rho)

return mds


class MDSI(nn.Module):
r"""Creates a criterion that measures the MDSI
between an input and a target.
Args:
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
`**kwargs` are transmitted to `mdsi`.
Call:
The input and target tensors should be of shape (N, 3, H, W).
"""

def __init__(self, reduction: str = 'mean', **kwargs):
super().__init__()

self.reduce = build_reduce(reduction)
self.kwargs = kwargs

def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
l = mdsi(input, target, **self.kwargs)

return self.reduce(l)
54 changes: 54 additions & 0 deletions spiq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Callable, List, Tuple

Expand Down Expand Up @@ -64,6 +65,59 @@ def gaussian_kernel(
return kernel


def gradient2d(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
r"""Returns the 2D gradient of `x` with respect to `kernel`.
Args:
x: An input tensor, (N, 1, H, W).
kernel: A 2D derivative kernel, (2, K, K).
"""

return F.conv2d(x, kernel, padding=kernel.size(-1) // 2)


def prewitt_kernel() -> torch.Tensor:
r"""Returns the (horizontal) 3x3 Prewitt kernel.
Wikipedia:
https://en.wikipedia.org/wiki/Prewitt_operator
"""

return torch.Tensor([
[1., 0., -1.],
[1., 0., -1.],
[1., 0., -1.],
]) / 3


def sobel_kernel() -> torch.Tensor:
r"""Returns the (horizontal) 3x3 Sobel kernel.
Wikipedia:
https://en.wikipedia.org/wiki/Sobel_operator
"""

return torch.Tensor([
[1., 0., -1.],
[2., 0., -2.],
[1., 0., -1.],
]) / 4


def scharr_kernel() -> torch.Tensor:
r"""Returns the (horizontal) 3x3 Scharr kernel.
Wikipedia:
https://en.wikipedia.org/wiki/Scharr_operator
"""

return torch.Tensor([
[3., 0., -3.],
[10., 0., -10.],
[3., 0., -3.],
]) / 16


def tensor_norm(
x: torch.Tensor,
dim: Tuple[int, ...] = (),
Expand Down

0 comments on commit e79b340

Please sign in to comment.