⚡️ Register convolution kernels in buffer
francois-rozet committed Dec 10, 2020
1 parent a78e4e3 commit e6179ae
Showing 3 changed files with 228 additions and 81 deletions.
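The core of the change: the gradient kernels and colour-conversion weights move out of the per-call functional code and into module buffers, so `.to()` / `.cuda()` and `state_dict()` handle them automatically instead of the kernels being rebuilt and moved to the device on every forward pass. Below is a minimal, self-contained sketch of that pattern; the Prewitt values and the toy module are assumptions for illustration, not piqa's exact code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GradientMagnitude(nn.Module):
    r"""Toy module illustrating the kernel-in-buffer pattern."""

    def __init__(self):
        super().__init__()

        # Standard Prewitt operator (assumed values), stacked into a
        # (2, 1, 3, 3) tensor: one horizontal and one vertical filter.
        prewitt = torch.tensor([
            [1., 0., -1.],
            [1., 0., -1.],
            [1., 0., -1.],
        ]) / 3.
        kernel = torch.stack([prewitt, prewitt.t()]).unsqueeze(1)

        # A buffer is not a learnable parameter, but it follows the module
        # across devices and dtypes and is saved in its state_dict().
        self.register_buffer('kernel', kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, 1, H, W) -> per-pixel gradient magnitude, (N, H, W)
        g = F.conv2d(x, self.kernel, padding=1)
        return g.norm(p=2, dim=1)

module = GradientMagnitude()
module = module.cuda() if torch.cuda.is_available() else module
x = torch.rand(5, 1, 64, 64, device=module.kernel.device)
print(module(x).shape)                   # torch.Size([5, 64, 64])
print('kernel' in module.state_dict())   # True

This is also why the doctests below switch from `GMSD()` to `GMSD().cuda()`: the registered kernel and luminance weights now travel with the module, and `forward` no longer has to call `.to(x.device)` on freshly built tensors.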
120 changes: 87 additions & 33 deletions piqa/gmsd.py
@@ -15,23 +15,71 @@

from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm

_L_WEIGHTS = torch.FloatTensor([0.2989, 0.587, 0.114])
_L_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114])


def gmsd(
def _gmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
value_range: float = 1.,
c: float = 0.00261, # 170. / (255. ** 2)
) -> torch.Tensor:
r"""Returns the GMSD between `x` and `y`,
without downsampling and color space conversion.
`_gmsd` is an auxiliary function for `gmsd` and `GMSD`.
Args:
x: An input tensor, (N, 1, H, W).
y: A target tensor, (N, 1, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
value_range: The value range of the inputs (usually 1. or 255).
For the remaining arguments, refer to [1].
Example:
>>> x = torch.rand(5, 1, 256, 256)
>>> y = torch.rand(5, 1, 256, 256)
>>> kernel = torch.rand(2, 1, 3, 3)
>>> l = _gmsd(x, y, kernel)
>>> l.size()
torch.Size([5])
"""

c *= value_range ** 2

# Gradient magnitude
pad = kernel.size(-1) // 2

gm_x = tensor_norm(filter2d(x, kernel, padding=pad), dim=1)
gm_y = tensor_norm(filter2d(y, kernel, padding=pad), dim=1)

# Gradient magnitude similarity
gms = (2. * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c)

# Gradient magnitude similarity deviation
gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2
gmsd = torch.sqrt(gmsd.mean((-1, -2)))

return gmsd


def gmsd(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""Returns the GMSD between `x` and `y`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
value_range: The value range of the inputs (usually 1. or 255).
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
For the remaining arguments, refer to [1].
`**kwargs` are transmitted to `_gmsd`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
@@ -41,70 +89,67 @@ def gmsd(
torch.Size([5])
"""

_, _, h, w = x.size()

# Downsample
padding = (0, w % 2, 0, h % 2)

if sum(padding) > 0:
x = F.pad(x, pad=padding)
y = F.pad(y, pad=padding)

x = F.avg_pool2d(x, kernel_size=2)
y = F.avg_pool2d(y, kernel_size=2)
x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True)
y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True)

# RGB to luminance
l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1)
l_weights /= value_range

x = F.conv2d(x, l_weights)
y = F.conv2d(y, l_weights)

# Gradient magnitude
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device)
# Kernel
if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)
kernel = kernel.to(x.device)

gm_x = tensor_norm(filter2d(x, kernel, padding=1), dim=1)
gm_y = tensor_norm(filter2d(y, kernel, padding=1), dim=1)

# Gradient magnitude similarity
gms = (2. * gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c)

# Gradient magnitude similarity deviation
gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2
gmsd = torch.sqrt(gmsd.mean((-1, -2)))

return gmsd
return _gmsd(x, y, kernel, **kwargs)


class GMSD(nn.Module):
r"""Creates a criterion that measures the GMSD
between an input and a target.
Args:
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
`**kwargs` are transmitted to `gmsd`.
`**kwargs` are transmitted to `_gmsd`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Example:
>>> criterion = GMSD()
>>> criterion = GMSD().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
def __init__(
self,
kernel: torch.Tensor = None,
reduction: str = 'mean',
**kwargs,
):
r""""""
super().__init__()

if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)

self.register_buffer('kernel', kernel)
self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1))

self.reduce = build_reduce(reduction)
self.kwargs = kwargs

@@ -116,6 +161,15 @@ def forward(
r"""Defines the computation performed at every call.
"""

l = gmsd(input, target, **self.kwargs)
# Downsample
input = F.avg_pool2d(input, 2, ceil_mode=True)
target = F.avg_pool2d(target, 2, ceil_mode=True)

# RGB to luminance
input = F.conv2d(input, self.l_weights)
target = F.conv2d(target, self.l_weights)

# GMSD
l = _gmsd(input, target, self.kernel, **self.kwargs)

return self.reduce(l)
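As a side note on the arithmetic kept inside `_gmsd`: the last two lines compute the population standard deviation of the gradient-magnitude-similarity map over the spatial dimensions. A small sketch of that equivalence, using a random stand-in for the GMS map rather than piqa's own code:

import torch

gms = torch.rand(5, 64, 64)  # stand-in for a per-pixel GMS map, (N, H, W)

# As in _gmsd: square root of the mean squared deviation from the spatial mean.
manual = ((gms - gms.mean((-1, -2), keepdim=True)) ** 2).mean((-1, -2)).sqrt()

# Population standard deviation over the flattened spatial dimensions.
builtin = gms.flatten(1).std(dim=1, unbiased=False)

print(torch.allclose(manual, builtin, atol=1e-6))  # True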
129 changes: 95 additions & 34 deletions piqa/mdsi.py
@@ -16,15 +16,16 @@
from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm

_LHM_WEIGHTS = torch.FloatTensor([
[0.2989, 0.587, 0.114],
[0.299, 0.587, 0.114],
[0.3, 0.04, -0.35],
[0.34, -0.6, 0.17],
])


def mdsi(
def _mdsi(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor,
value_range: float = 1.,
combination: str = 'sum',
c1: float = 0.00215, # 140. / (255. ** 2)
@@ -37,11 +38,15 @@ def mdsi(
q: float = 0.25,
o: float = 0.25,
) -> torch.Tensor:
r"""Returns the MDSI between `x` and `y`.
r"""Returns the MDSI between `x` and `y`,
without downsampling and color space conversion.
`_mdsi` is an auxiliary function for `mdsi` and `MDSI`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
value_range: The value range of the inputs (usually 1. or 255).
combination: Specifies the scheme to combine the gradient
and chromaticity similarities (GS, CS):
@@ -50,41 +55,25 @@
For the remaining arguments, refer to [1].
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = mdsi(x, y)
>>> x = torch.rand(5, 1, 256, 256)
>>> y = torch.rand(5, 1, 256, 256)
>>> kernel = torch.rand(2, 1, 3, 3)
>>> l = _mdsi(x, y, kernel)
>>> l.size()
torch.Size([5])
"""

_, _, h, w = x.size()

# Downsample
M = max(1, min(h, w) // 256)
padding = (0, M - (w - 1 % M) + 1, 0, M - (h - 1 % M) + 1)

if sum(padding) > 0:
x = F.pad(x, pad=padding)
y = F.pad(y, pad=padding)

x = F.avg_pool2d(x, kernel_size=M)
y = F.avg_pool2d(y, kernel_size=M)

# RGB to LHM
lhm_weights = _LHM_WEIGHTS.to(x.device).view(3, 3, 1, 1)
lhm_weights /= value_range

x = F.conv2d(x, lhm_weights)
y = F.conv2d(y, lhm_weights)
c1 *= value_range ** 2
c2 *= value_range ** 2
c3 *= value_range ** 2

# Gradient magnitude
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device)
pad = kernel.size(-1) // 2

gm_x = tensor_norm(filter2d(x[:, :1], kernel, padding=1), dim=1)
gm_y = tensor_norm(filter2d(y[:, :1], kernel, padding=1), dim=1)
gm_x = tensor_norm(filter2d(x[:, :1], kernel, padding=pad), dim=1)
gm_y = tensor_norm(filter2d(y[:, :1], kernel, padding=pad), dim=1)
gm_avg = tensor_norm(
filter2d((x + y)[:, :1] / 2., kernel, padding=1),
filter2d((x + y)[:, :1] / 2., kernel, padding=pad),
dim=1,
)

@@ -120,34 +109,94 @@ def mdsi(
return mds


def mdsi(
x: torch.Tensor,
y: torch.Tensor,
kernel: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""Returns the MDSI between `x` and `y`.
Args:
x: An input tensor, (N, 3, H, W).
y: A target tensor, (N, 3, H, W).
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
`**kwargs` are transmitted to `_mdsi`.
Example:
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> l = mdsi(x, y)
>>> l.size()
torch.Size([5])
"""

# Downsample
_, _, h, w = x.size()
M = max(1, min(h, w) // 256)

x = F.avg_pool2d(x, kernel_size=M, ceil_mode=True)
y = F.avg_pool2d(y, kernel_size=M, ceil_mode=True)

# RGB to LHM
lhm_weights = _LHM_WEIGHTS.to(x.device).view(3, 3, 1, 1)

x = F.conv2d(x, lhm_weights)
y = F.conv2d(y, lhm_weights)

# Kernel
if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)
kernel = kernel.to(x.device)

return _mdsi(x, y, kernel, **kwargs)


class MDSI(nn.Module):
r"""Creates a criterion that measures the MDSI
between an input and a target.
Args:
kernel: A 2D gradient kernel, (2, 1, K, K).
If `None`, use the Prewitt kernel instead.
reduction: Specifies the reduction to apply to the output:
`'none'` | `'mean'` | `'sum'`.
`**kwargs` are transmitted to `mdsi`.
`**kwargs` are transmitted to `_mdsi`.
Shape:
* Input: (N, 3, H, W)
* Target: (N, 3, H, W)
* Output: (N,) or (1,) depending on `reduction`
Example:
>>> criterion = MDSI()
>>> criterion = MDSI().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
"""

def __init__(self, reduction: str = 'mean', **kwargs):
def __init__(
self,
kernel: torch.Tensor = None,
reduction: str = 'mean',
**kwargs,
):
r""""""
super().__init__()

if kernel is None:
kernel = prewitt_kernel()
kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1)

self.register_buffer('kernel', kernel)
self.register_buffer('lhm_weights', _LHM_WEIGHTS.view(3, 3, 1, 1))

self.reduce = build_reduce(reduction)
self.kwargs = kwargs

@@ -159,6 +208,18 @@ def forward(
r"""Defines the computation performed at every call.
"""

l = mdsi(input, target, **self.kwargs)
# Downsample
_, _, h, w = input.size()
M = max(1, min(h, w) // 256)

input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True)
target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True)

# RGB to LHM
input = F.conv2d(input, self.lhm_weights)
target = F.conv2d(target, self.lhm_weights)

# MDSI
l = _mdsi(input, target, self.kernel, **self.kwargs)

return self.reduce(l)
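Both files also drop the manual pad-then-pool downsampling in favour of `avg_pool2d(..., ceil_mode=True)`, which keeps a trailing partial window instead of discarding it. A quick sketch of the resulting output sizes (the kernel size 2 stands in for `gmsd`'s fixed factor and `mdsi`'s computed `M`):

import torch
import torch.nn.functional as F

# Odd spatial size, so a 2x2 pooling window does not tile the image exactly.
x = torch.rand(1, 3, 255, 255)

floor_pool = F.avg_pool2d(x, kernel_size=2)                  # (1, 3, 127, 127)
ceil_pool = F.avg_pool2d(x, kernel_size=2, ceil_mode=True)   # (1, 3, 128, 128)

print(floor_pool.shape, ceil_pool.shape)

With ceil_mode left at its default, the last row and column of a 255x255 input never contribute to the pooled output; the previous code worked around this by explicitly padding to a compatible size before pooling.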