From e6179aeaa4d5a917079dbd9b008151b087007145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Thu, 10 Dec 2020 02:18:37 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Register=20convolution=20k?= =?UTF-8?q?ernels=20in=20buffer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- piqa/gmsd.py | 120 ++++++++++++++++++++++++++++++++++------------- piqa/mdsi.py | 129 +++++++++++++++++++++++++++++++++++++-------------- piqa/ssim.py | 60 ++++++++++++++++++------ 3 files changed, 228 insertions(+), 81 deletions(-) diff --git a/piqa/gmsd.py b/piqa/gmsd.py index 84ead52..5016e7e 100644 --- a/piqa/gmsd.py +++ b/piqa/gmsd.py @@ -15,23 +15,71 @@ from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm -_L_WEIGHTS = torch.FloatTensor([0.2989, 0.587, 0.114]) +_L_WEIGHTS = torch.FloatTensor([0.299, 0.587, 0.114]) -def gmsd( +def _gmsd( x: torch.Tensor, y: torch.Tensor, + kernel: torch.Tensor, value_range: float = 1., c: float = 0.00261, # 170. / (255. ** 2) +) -> torch.Tensor: + r"""Returns the GMSD between `x` and `y`, + without downsampling and color space conversion. + + `_gmsd` is an auxiliary function for `gmsd` and `GMSD`. + + Args: + x: An input tensor, (N, 1, H, W). + y: A target tensor, (N, 1, H, W). + kernel: A 2D gradient kernel, (2, 1, K, K). + value_range: The value range of the inputs (usually 1. or 255). + + For the remaining arguments, refer to [1]. + + Example: + >>> x = torch.rand(5, 1, 256, 256) + >>> y = torch.rand(5, 1, 256, 256) + >>> kernel = torch.rand(2, 1, 3, 3) + >>> l = _gmsd(x, y, kernel) + >>> l.size() + torch.Size([5]) + """ + + c *= value_range ** 2 + + # Gradient magnitude + pad = kernel.size(-1) // 2 + + gm_x = tensor_norm(filter2d(x, kernel, padding=pad), dim=1) + gm_y = tensor_norm(filter2d(y, kernel, padding=pad), dim=1) + + # Gradient magnitude similarity + gms = (2. 
* gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c) + + # Gradient magnitude similarity deviation + gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2 + gmsd = torch.sqrt(gmsd.mean((-1, -2))) + + return gmsd + + +def gmsd( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor = None, + **kwargs, ) -> torch.Tensor: r"""Returns the GMSD between `x` and `y`. Args: x: An input tensor, (N, 3, H, W). y: A target tensor, (N, 3, H, W). - value_range: The value range of the inputs (usually 1. or 255). + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. - For the remaining arguments, refer to [1]. + `**kwargs` are transmitted to `_gmsd`. Example: >>> x = torch.rand(5, 3, 256, 256) @@ -41,40 +89,23 @@ def gmsd( torch.Size([5]) """ - _, _, h, w = x.size() - # Downsample - padding = (0, w % 2, 0, h % 2) - - if sum(padding) > 0: - x = F.pad(x, pad=padding) - y = F.pad(y, pad=padding) - - x = F.avg_pool2d(x, kernel_size=2) - y = F.avg_pool2d(y, kernel_size=2) + x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True) + y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True) # RGB to luminance l_weights = _L_WEIGHTS.to(x.device).view(1, 3, 1, 1) - l_weights /= value_range x = F.conv2d(x, l_weights) y = F.conv2d(y, l_weights) - # Gradient magnitude - kernel = prewitt_kernel() - kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device) + # Kernel + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + kernel = kernel.to(x.device) - gm_x = tensor_norm(filter2d(x, kernel, padding=1), dim=1) - gm_y = tensor_norm(filter2d(y, kernel, padding=1), dim=1) - - # Gradient magnitude similarity - gms = (2. 
* gm_x * gm_y + c) / (gm_x ** 2 + gm_y ** 2 + c) - - # Gradient magnitude similarity diviation - gmsd = (gms - gms.mean((-1, -2), keepdim=True)) ** 2 - gmsd = torch.sqrt(gmsd.mean((-1, -2))) - - return gmsd + return _gmsd(x, y, kernel, **kwargs) class GMSD(nn.Module): @@ -82,10 +113,12 @@ class GMSD(nn.Module): between an input and a target. Args: + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. reduction: Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. - `**kwargs` are transmitted to `gmsd`. + `**kwargs` are transmitted to `_gmsd`. Shape: * Input: (N, 3, H, W) @@ -93,7 +126,7 @@ class GMSD(nn.Module): * Output: (N,) or (1,) depending on `reduction` Example: - >>> criterion = GMSD() + >>> criterion = GMSD().cuda() >>> x = torch.rand(5, 3, 256, 256).cuda() >>> y = torch.rand(5, 3, 256, 256).cuda() >>> l = criterion(x, y) @@ -101,10 +134,22 @@ class GMSD(nn.Module): torch.Size([]) """ - def __init__(self, reduction: str = 'mean', **kwargs): + def __init__( + self, + kernel: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): r"""""" super().__init__() + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + + self.register_buffer('kernel', kernel) + self.register_buffer('l_weights', _L_WEIGHTS.view(1, 3, 1, 1)) + self.reduce = build_reduce(reduction) self.kwargs = kwargs @@ -116,6 +161,15 @@ def forward( r"""Defines the computation performed at every call. 
""" - l = gmsd(input, target, **self.kwargs) + # Downsample + input = F.avg_pool2d(input, 2, ceil_mode=True) + target = F.avg_pool2d(target, 2, ceil_mode=True) + + # RGB to luminance + input = F.conv2d(input, self.l_weights) + target = F.conv2d(target, self.l_weights) + + # GMSD + l = _gmsd(input, target, self.kernel, **self.kwargs) return self.reduce(l) diff --git a/piqa/mdsi.py b/piqa/mdsi.py index 21e19a4..c34c046 100644 --- a/piqa/mdsi.py +++ b/piqa/mdsi.py @@ -16,15 +16,16 @@ from piqa.utils import build_reduce, prewitt_kernel, filter2d, tensor_norm _LHM_WEIGHTS = torch.FloatTensor([ - [0.2989, 0.587, 0.114], + [0.299, 0.587, 0.114], [0.3, 0.04, -0.35], [0.34, -0.6, 0.17], ]) -def mdsi( +def _mdsi( x: torch.Tensor, y: torch.Tensor, + kernel: torch.Tensor, value_range: float = 1., combination: str = 'sum', c1: float = 0.00215, # 140. / (255. ** 2) @@ -37,11 +38,15 @@ def mdsi( q: float = 0.25, o: float = 0.25, ) -> torch.Tensor: - r"""Returns the MDSI between `x` and `y`. + r"""Returns the MDSI between `x` and `y`, + without downsampling and color space conversion. + + `_mdsi` is an auxiliary function for `mdsi` and `MDSI`. Args: x: An input tensor, (N, 3, H, W). y: A target tensor, (N, 3, H, W). + kernel: A 2D gradient kernel, (2, 1, K, K). value_range: The value range of the inputs (usually 1. or 255). combination: Specifies the scheme to combine the gradient and chromaticity similarities (GS, CS): @@ -50,41 +55,25 @@ def mdsi( For the remaining arguments, refer to [1]. 
Example: - >>> x = torch.rand(5, 3, 256, 256) - >>> y = torch.rand(5, 3, 256, 256) - >>> l = mdsi(x, y) + >>> x = torch.rand(5, 1, 256, 256) + >>> y = torch.rand(5, 1, 256, 256) + >>> kernel = torch.rand(2, 1, 3, 3) + >>> l = _mdsi(x, y, kernel) >>> l.size() torch.Size([5]) """ - _, _, h, w = x.size() - - # Downsample - M = max(1, min(h, w) // 256) - padding = (0, M - (w - 1 % M) + 1, 0, M - (h - 1 % M) + 1) - - if sum(padding) > 0: - x = F.pad(x, pad=padding) - y = F.pad(y, pad=padding) - - x = F.avg_pool2d(x, kernel_size=M) - y = F.avg_pool2d(y, kernel_size=M) - - # RGB to LHM - lhm_weights = _LHM_WEIGHTS.to(x.device).view(3, 3, 1, 1) - lhm_weights /= value_range - - x = F.conv2d(x, lhm_weights) - y = F.conv2d(y, lhm_weights) + c1 *= value_range ** 2 + c2 *= value_range ** 2 + c3 *= value_range ** 2 # Gradient magnitude - kernel = prewitt_kernel() - kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1).to(x.device) + pad = kernel.size(-1) // 2 - gm_x = tensor_norm(filter2d(x[:, :1], kernel, padding=1), dim=1) - gm_y = tensor_norm(filter2d(y[:, :1], kernel, padding=1), dim=1) + gm_x = tensor_norm(filter2d(x[:, :1], kernel, padding=pad), dim=1) + gm_y = tensor_norm(filter2d(y[:, :1], kernel, padding=pad), dim=1) gm_avg = tensor_norm( - filter2d((x + y)[:, :1] / 2., kernel, padding=1), + filter2d((x + y)[:, :1] / 2., kernel, padding=pad), dim=1, ) @@ -120,15 +109,63 @@ def mdsi( return mds +def mdsi( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor = None, + **kwargs, +) -> torch.Tensor: + r"""Returns the MDSI between `x` and `y`. + + Args: + x: An input tensor, (N, 3, H, W). + y: A target tensor, (N, 3, H, W). + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. + + `**kwargs` are transmitted to `_mdsi`. 
+ + Example: + >>> x = torch.rand(5, 3, 256, 256) + >>> y = torch.rand(5, 3, 256, 256) + >>> l = mdsi(x, y) + >>> l.size() + torch.Size([5]) + """ + + # Downsample + _, _, h, w = x.size() + M = max(1, min(h, w) // 256) + + x = F.avg_pool2d(x, kernel_size=M, ceil_mode=True) + y = F.avg_pool2d(y, kernel_size=M, ceil_mode=True) + + # RGB to LHM + lhm_weights = _LHM_WEIGHTS.to(x.device).view(3, 3, 1, 1) + + x = F.conv2d(x, lhm_weights) + y = F.conv2d(y, lhm_weights) + + # Kernel + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + kernel = kernel.to(x.device) + + return _mdsi(x, y, kernel, **kwargs) + + class MDSI(nn.Module): r"""Creates a criterion that measures the MDSI between an input and a target. Args: + kernel: A 2D gradient kernel, (2, 1, K, K). + If `None`, use the Prewitt kernel instead. reduction: Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. - `**kwargs` are transmitted to `mdsi`. + `**kwargs` are transmitted to `_mdsi`. Shape: * Input: (N, 3, H, W) @@ -136,7 +173,7 @@ class MDSI(nn.Module): * Output: (N,) or (1,) depending on `reduction` Example: - >>> criterion = MDSI() + >>> criterion = MDSI().cuda() >>> x = torch.rand(5, 3, 256, 256).cuda() >>> y = torch.rand(5, 3, 256, 256).cuda() >>> l = criterion(x, y) @@ -144,10 +181,22 @@ class MDSI(nn.Module): torch.Size([]) """ - def __init__(self, reduction: str = 'mean', **kwargs): + def __init__( + self, + kernel: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): r"""""" super().__init__() + if kernel is None: + kernel = prewitt_kernel() + kernel = torch.stack([kernel, kernel.t()]).unsqueeze(1) + + self.register_buffer('kernel', kernel) + self.register_buffer('lhm_weights', _LHM_WEIGHTS.view(3, 3, 1, 1)) + self.reduce = build_reduce(reduction) self.kwargs = kwargs @@ -159,6 +208,18 @@ def forward( r"""Defines the computation performed at every call. 
""" - l = mdsi(input, target, **self.kwargs) + # Downsample + _, _, h, w = input.size() + M = max(1, min(h, w) // 256) + + input = F.avg_pool2d(input, kernel_size=M, ceil_mode=True) + target = F.avg_pool2d(target, kernel_size=M, ceil_mode=True) + + # RGB to LHM + input = F.conv2d(input, self.lhm_weights) + target = F.conv2d(target, self.lhm_weights) + + # MDSI + l = _mdsi(input, target, self.kernel, **self.kwargs) return self.reduce(l) diff --git a/piqa/ssim.py b/piqa/ssim.py index dcbe732..a59a150 100644 --- a/piqa/ssim.py +++ b/piqa/ssim.py @@ -68,7 +68,7 @@ def ssim_per_channel( Args: x: An input tensor, (N, C, H, W). y: A target tensor, (N, C, H, W). - window: A convolution window. + window: A convolution window, (C, 1, K, K). value_range: The value range of the inputs (usually 1. or 255). For the remaining arguments, refer to [1]. @@ -140,6 +140,7 @@ def msssim_per_channel( x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, + weights: torch.Tensor, **kwargs, ) -> torch.Tensor: """Returns the MS-SSIM per channel between `x` and `y`. @@ -147,7 +148,8 @@ def msssim_per_channel( Args: x: An input tensor, (N, C, H, W). y: A target tensor, (N, C, H, W). - window: A convolution window. + window: A convolution window, (C, 1, K, K). + weights: The weights of the scales, (M,). `**kwargs` are transmitted to `ssim_per_channel`. 
@@ -155,25 +157,23 @@ def msssim_per_channel( >>> x = torch.rand(5, 3, 256, 256) >>> y = torch.rand(5, 3, 256, 256) >>> window = create_window(7, 3) - >>> l = msssim_per_channel(x, y, window) + >>> weights = torch.rand(5) + >>> l = msssim_per_channel(x, y, window, weights) >>> l.size() torch.Size([5, 3]) """ - weights = _WEIGHTS.to(x.device) - - mcs = [] + css = [] for i in range(weights.numel()): if i > 0: - padding = (x.shape[-2] % 2, x.shape[-1] % 2) - x = F.avg_pool2d(x, kernel_size=2, padding=padding) - y = F.avg_pool2d(y, kernel_size=2, padding=padding) + x = F.avg_pool2d(x, kernel_size=2, ceil_mode=True) + y = F.avg_pool2d(y, kernel_size=2, ceil_mode=True) ss, cs = ssim_per_channel(x, y, window, **kwargs) - mcs.append(torch.relu(cs)) + css.append(torch.relu(cs)) - msss = torch.stack(mcs[:-1] + [ss], dim=-1) + msss = torch.stack(css[:-1] + [ss], dim=-1) msss = (msss ** weights).prod(dim=-1) return msss @@ -183,6 +183,7 @@ def msssim( x: torch.Tensor, y: torch.Tensor, window_size: int = 11, + weights: torch.Tensor = None, **kwargs, ) -> torch.Tensor: r"""Returns the MS-SSIM between `x` and `y`. @@ -191,6 +192,8 @@ def msssim( x: An input tensor, (N, C, H, W). y: A target tensor, (N, C, H, W). window_size: The size of the window. + weights: The weights of the scales, (M,). + If `None`, use the official weights instead. `**kwargs` are transmitted to `msssim_per_channel`. @@ -205,7 +208,10 @@ def msssim( n_channels = x.size(1) window = create_window(window_size, n_channels).to(x.device) - return msssim_per_channel(x, y, window, **kwargs).mean(-1) + if weights is None: + weights = _WEIGHTS.to(x.device) + + return msssim_per_channel(x, y, window, weights, **kwargs).mean(-1) class SSIM(nn.Module): @@ -267,12 +273,17 @@ def forward( return self.reduce(l) -class MSSSIM(SSIM): +class MSSSIM(nn.Module): r"""Creates a criterion that measures the MS-SSIM between an input and a target. Args: - All arguments are inherited from `SSIM`. + window_size: The size of the window. 
+ n_channels: A number of channels. + weights: The weights of the scales, (M,). + If `None`, use the official weights instead. + reduction: Specifies the reduction to apply to the output: + `'none'` | `'mean'` | `'sum'`. `**kwargs` are transmitted to `msssim_per_channel`. @@ -290,6 +301,26 @@ class MSSSIM(SSIM): torch.Size([]) """ + def __init__( + self, + window_size: int = 11, + n_channels: int = 3, + weights: torch.Tensor = None, + reduction: str = 'mean', + **kwargs, + ): + r"""""" + super().__init__() + + if weights is None: + weights = _WEIGHTS + + self.register_buffer('window', create_window(window_size, n_channels)) + self.register_buffer('weights', weights) + + self.reduce = build_reduce(reduction) + self.kwargs = kwargs + def forward( self, input: torch.Tensor, @@ -302,6 +333,7 @@ def forward( input, target, window=self.window, + weights=self.weights, **self.kwargs, ).mean(-1)