Skip to content

Commit

Permalink
🐛 Fix MDSI backward pass
Browse files Browse the repository at this point in the history
As of PyTorch 1.7.0, some operations on complex tensors (torch.cfloat) do not support CUDA, automatic differentiation and/or JIT (f73b13c, 0e89000). To overcome these limitations, complex tensors were replaced by real tensors.

Unit tests with backward pass have also been added.
  • Loading branch information
francois-rozet committed Jan 15, 2021
1 parent c829915 commit c6d924d
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 66 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ The `piqa` package is divided in several submodules, each of which implements th
import torch
from piqa import psnr, ssim

x = torch.rand(3, 3, 256, 256).cuda()
y = torch.rand(3, 3, 256, 256).cuda()
x = torch.rand(3, 3, 256, 256, requires_grad=True).cuda()
y = torch.rand(3, 3, 256, 256, requires_grad=True).cuda()

# PSNR function
l = psnr.psnr(x, y)

# SSIM instantiable object
criterion = ssim.SSIM().cuda()
l = criterion(x, y)
l.backward()
```

### Metrics
Expand Down
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessement metric.
"""

__version__ = '1.0.12'
__version__ = '1.0.13'
10 changes: 6 additions & 4 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,12 @@ class GMSD(nn.Module):
Example:
>>> criterion = GMSD().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down Expand Up @@ -315,11 +316,12 @@ class MSGMSD(nn.Module):
Example:
>>> criterion = MSGMSD().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down
20 changes: 12 additions & 8 deletions piqa/haarpsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _haarpsi(
c *= value_range ** 2

# Y
y_x, y_y = x[:, :1], y[:, :1]

## Gradient(s)
g_xy: List[Tuple[Tensor, Tensor]] = []
Expand All @@ -82,10 +83,10 @@ def _haarpsi(
### Haar filter (gradient)
pad = kernel_size // 2

g_x = channel_conv(x[:, :1], kernel, padding=pad).abs()
g_y = channel_conv(y[:, :1], kernel, padding=pad).abs()
g_x = channel_conv(y_x, kernel, padding=pad)[..., 1:, 1:].abs()
g_y = channel_conv(y_y, kernel, padding=pad)[..., 1:, 1:].abs()

g_xy.append((g_x[..., 1:, 1:], g_y[..., 1:, 1:]))
g_xy.append((g_x, g_y))

## Gradient similarity(ies)
gs = []
Expand All @@ -100,15 +101,17 @@ def _haarpsi(

# IQ
if x.size(1) == 3:
iq_x, iq_y = x[:, 1:], y[:, 1:]

## Mean filter
m_x = F.avg_pool2d(x[:, 1:], 2, stride=1, padding=1).abs()
m_y = F.avg_pool2d(y[:, 1:], 2, stride=1, padding=1).abs()
m_x = F.avg_pool2d(iq_x, 2, stride=1, padding=1)[..., 1:, 1:].abs()
m_y = F.avg_pool2d(iq_y, 2, stride=1, padding=1)[..., 1:, 1:].abs()

## Chromatic similarity(ies)
cs = (2. * m_x * m_y + c) / (m_x ** 2 + m_y ** 2 + c)

## Local similarity(ies)
ls = torch.cat([ls, cs[..., 1:, 1:].mean(1, True)], dim=1) # (N, 3, H, W)
ls = torch.cat([ls, cs.mean(1, True)], dim=1) # (N, 3, H, W)

## Weight(s)
w = torch.cat([w, w.mean(1, True)], dim=1) # (N, 3, H, W)
Expand Down Expand Up @@ -180,11 +183,12 @@ class HaarPSI(nn.Module):
Example:
>>> criterion = HaarPSI().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down
5 changes: 3 additions & 2 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ class LPIPS(nn.Module):
Example:
>>> criterion = LPIPS().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down
34 changes: 21 additions & 13 deletions piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
gradient_kernel,
channel_conv,
tensor_norm,
cstack,
cprod,
cpow,
cabs,
)

_LHM_WEIGHTS = torch.FloatTensor([
Expand All @@ -32,6 +35,7 @@
])


@_jit
def _mdsi(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -77,13 +81,16 @@ def _mdsi(
c2 *= value_range ** 2
c3 *= value_range ** 2

l_x, hm_x = x[:, :1], x[:, 1:]
l_y, hm_y = y[:, :1], y[:, 1:]

# Gradient magnitude
pad = kernel.size(-1) // 2

gm_x = tensor_norm(channel_conv(x[:, :1], kernel, padding=pad), dim=[1])
gm_y = tensor_norm(channel_conv(y[:, :1], kernel, padding=pad), dim=[1])
gm_x = tensor_norm(channel_conv(l_x, kernel, padding=pad), dim=[1])
gm_y = tensor_norm(channel_conv(l_y, kernel, padding=pad), dim=[1])
gm_avg = tensor_norm(
channel_conv((x + y)[:, :1] / 2., kernel, padding=pad),
channel_conv((l_x + l_y) / 2., kernel, padding=pad),
dim=[1],
)

Expand All @@ -97,24 +104,24 @@ def _mdsi(
gs = gs_x_y + gs_x_avg - gs_y_avg

# Chromaticity similarity
cs_num = 2. * (x[:, 1:] * y[:, 1:]).sum(1) + c3
cs_den = (x[:, 1:] ** 2 + y[:, 1:] ** 2).sum(1) + c3
cs_num = 2. * (hm_x * hm_y).sum(1) + c3
cs_den = (hm_x ** 2 + hm_y ** 2).sum(1) + c3
cs = cs_num / cs_den

# Gradient-chromaticity similarity
gs, cs = gs.type(torch.cfloat), cs.type(torch.cfloat)
gs = cstack(gs, torch.zeros_like(gs))
cs = cstack(cs, torch.zeros_like(cs))

if combination == 'prod':
gcs = (gs ** gamma) * (cs ** beta)
gcs = cprod(cpow(gs, gamma), cpow(cs, beta))
else: # combination == 'sum'
gcs = alpha * gs + (1. - alpha) * cs

# Mean deviation similarity
gcs_q = cpow(gcs, q)
gcs_q_avg = torch.view_as_real(gcs_q).mean((-2, -3), True)
gcs_q_avg = torch.view_as_complex(gcs_q_avg)
score = (gcs_q - gcs_q_avg).abs()
mds = (score ** rho).mean((-1, -2)) ** (o / rho)
gcs_q_avg = gcs_q.mean((-2, -3), True)
score = cabs(gcs_q - gcs_q_avg, squared=True) ** (rho / 2)
mds = score.mean((-1, -2)) ** (o / rho)

return mds

Expand Down Expand Up @@ -182,11 +189,12 @@ class MDSI(nn.Module):
Example:
>>> criterion = MDSI().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down
5 changes: 3 additions & 2 deletions piqa/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ class PSNR(nn.Module):
Example:
>>> criterion = PSNR()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand Down
9 changes: 5 additions & 4 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ class SSIM(nn.Module):
Example:
>>> criterion = SSIM().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down Expand Up @@ -307,11 +307,12 @@ class MSSSIM(nn.Module):
Example:
>>> criterion = MSSSIM().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> y = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(
Expand Down
11 changes: 6 additions & 5 deletions piqa/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor:
w_var = x[..., :, 1:] - x[..., :, :-1]
h_var = x[..., 1:, :] - x[..., :-1, :]

if norm in ['L2', 'L2_squared']:
w_var = w_var ** 2
h_var = h_var ** 2
else: # norm == 'L1'
if norm == 'L1':
w_var = w_var.abs()
h_var = h_var.abs()
else: # norm in ['L2', 'L2_squared']
w_var = w_var ** 2
h_var = h_var ** 2

var = w_var.sum(dim=(-1, -2, -3)) + h_var.sum(dim=(-1, -2, -3))

Expand All @@ -61,10 +61,11 @@ class TV(nn.Module):
Example:
>>> criterion = TV()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> x = torch.rand(5, 3, 256, 256, requires_grad=True).cuda()
>>> l = criterion(x)
>>> l.size()
torch.Size([])
>>> l.backward()
"""

def __init__(self, reduction: str = 'mean', **kwargs):
Expand Down
Loading

0 comments on commit c6d924d

Please sign in to comment.