diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 0000000..47b1747 --- /dev/null +++ b/.style.yapf @@ -0,0 +1,10 @@ +[style] +based_on_style = google + +spaces_around_power_operator = true + +dedent_closing_brackets = true +coalesce_brackets = true +split_arguments_when_comma_terminated = true +split_all_top_level_comma_separated_values = true +split_complex_comprehension = true diff --git a/README.md b/README.md index ddb6c9b..578e071 100644 --- a/README.md +++ b/README.md @@ -47,3 +47,5 @@ The [documentation](https://francois-rozet.github.io/spiq/) of this package is g ```bash pdoc spiq --html --config "git_link_template='https://github.com/francois-rozet/spiq/blob/{commit}/{path}#L{start_line}-L{end_line}'" ``` + +> The code follows the [Google Python style](https://google.github.io/styleguide/pyguide.html) and is compliant with [YAPF](https://github.com/google/yapf). diff --git a/spiq/__init__.py b/spiq/__init__.py index bfa505c..bf3ecad 100644 --- a/spiq/__init__.py +++ b/spiq/__init__.py @@ -1,6 +1,8 @@ r"""Simple PyTorch Image Quality -The spiq package is divided in several submodules, each of which implements the functions and/or classes related to a specific image quality metric. +The spiq package is divided in several submodules, each of +which implements the functions and/or classes related to a +specific image quality metric. 
""" __version__ = '0.0.2' diff --git a/spiq/lpips.py b/spiq/lpips.py index 97ae9d8..8b02899 100644 --- a/spiq/lpips.py +++ b/spiq/lpips.py @@ -11,10 +11,6 @@ https://arxiv.org/abs/1801.03924 """ -########### -# Imports # -########### - import inspect import os import torch @@ -23,32 +19,30 @@ from spiq.utils import normalize_tensor, Intermediary - -############# -# Constants # -############# - _SHIFT = torch.Tensor([0.485, 0.456, 0.406]) _SCALE = torch.Tensor([0.229, 0.224, 0.225]) -########### -# Classes # -########### - class LPIPS(nn.Module): - r"""Creates a criterion that measures the LPIPS between an input and a target. + r"""Creates a criterion that measures the LPIPS + between an input and a target. Args: - network: perception network name (`'AlexNet'`, `'SqueezeNet'` or `'VGG16'`) - scaling: whether the input and target are sclaed w.r.t. ImageNet - reduction: reduction type (`'mean'`, `'sum'` or `'none'`) + network: A perception network name (`'AlexNet'`, + `'SqueezeNet'` or `'VGG16'`). + scaling: Whether the input and target are scaled w.r.t. ImageNet. + reduction: A reduction type (`'mean'`, `'sum'` or `'none'`). Call: The input and target tensors should be of shape (N, C, H, W). 
""" - def __init__(self, network: str = 'AlexNet', scaling: bool = False, reduction: str = 'mean'): + def __init__( + self, + network: str = 'AlexNet', + scaling: bool = False, + reduction: str = 'mean', + ): super().__init__() # ImageNet scaling @@ -93,14 +87,18 @@ def __init__(self, network: str = 'AlexNet', scaling: bool = False, reduction: s self.reduction = reduction - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: if self.scaling: input = (input - self.shift) / self.scale target = (target - self.shift) / self.scale residuals = [] - for loss, (fx, fy) in zip(self.lin, zip(self.net(input), self.net(target))): + for loss, fx, fy in zip(self.lin, self.net(input), self.net(target)): fx = normalize_tensor(fx, dim=1, norm='L2') fy = normalize_tensor(fy, dim=1, norm='L2') diff --git a/spiq/psnr.py b/spiq/psnr.py index a837e47..df976a1 100644 --- a/spiq/psnr.py +++ b/spiq/psnr.py @@ -6,46 +6,42 @@ https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio """ -########### -# Imports # -########### - import torch import torch.nn as nn from typing import Tuple -############# -# Functions # -############# - -def psnr(x: torch.Tensor, y: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = False, value_range: float = 1., epsilon: float = 1e-8) -> torch.Tensor: +def psnr( + x: torch.Tensor, + y: torch.Tensor, + dim: Tuple[int, ...] = (), + keepdim: bool = False, + value_range: float = 1., + epsilon: float = 1e-8, +) -> torch.Tensor: r"""Returns the PSNR between `x` and `y`. Args: - x: input tensor - y: target tensor - dim: dimension(s) along which to average - keepdim: whether the output tensor has `dim` retained or not - value_range: value range of the inputs (usually 1. or 255) - epsilon: numerical stability term + x: An input tensor. + y: A target tensor. + dim: The dimension(s) along which to average. 
+ keepdim: Whether the output tensor has `dim` retained or not. + value_range: The value range of the inputs (usually 1. or 255). + epsilon: A numerical stability term. """ mse = ((x - y) ** 2).mean(dim=dim, keepdim=keepdim) + epsilon return 10 * torch.log10(value_range ** 2 / mse) -########### -# Classes # -########### - class PSNR(nn.Module): - r"""Creates a criterion that measures the PSNR between an input and a target. + r"""Creates a criterion that measures the PSNR + between an input and a target. Args: - value_range: value range of the inputs (usually 1. or 255) - reduction: reduction type (`'mean'`, `'sum'` or `'none'`) + value_range: The value range of the inputs (usually 1. or 255). + reduction: A reduction type (`'mean'`, `'sum'` or `'none'`). Call: The input and target tensors should be of shape (N, ...). @@ -57,11 +53,16 @@ def __init__(self, value_range: float = 1., reduction: str = 'mean'): self.value_range = value_range self.reduction = reduction - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: l = psnr( - input, target, + input, + target, dim=tuple(range(1, input.ndimension())), - value_range=self.value_range + value_range=self.value_range, ) if self.reduction == 'mean': diff --git a/spiq/ssim.py b/spiq/ssim.py index 9a6e1d7..657a5fb 100644 --- a/spiq/ssim.py +++ b/spiq/ssim.py @@ -18,36 +18,23 @@ https://ieeexplore.ieee.org/abstract/document/1284395/ """ -########### -# Imports # -########### - import torch import torch.nn as nn import torch.nn.functional as F from spiq.utils import gaussian_kernel - -############# -# Constants # -############# - _SIGMA = 1.5 _K1, _K2 = 0.01, 0.03 _WEIGHTS = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) -############# -# Functions # -############# - def create_window(window_size: int, n_channels: int) -> torch.Tensor: r"""Returns the SSIM convolution window (kernel) of size 
`window_size`. Args: - window_size: size of the window - n_channels: number of channels + window_size: The size of the window. + n_channels: The number of channels. """ kernel = gaussian_kernel(window_size, _SIGMA) @@ -58,14 +45,20 @@ def create_window(window_size: int, n_channels: int) -> torch.Tensor: return window -def ssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, value_range: float = 1.) -> torch.Tensor: - r"""Returns the SSIM and the contrast sensitivity (CS) per channel between `x` and `y`. +def ssim_per_channel( + x: torch.Tensor, + y: torch.Tensor, + window: torch.Tensor, + value_range: float = 1., +) -> torch.Tensor: + r"""Returns the SSIM and the contrast sensitivity (CS) + per channel between `x` and `y`. Args: - x: input tensor, (N, C, H, W) - y: target tensor, (N, C, H, W) - window: convolution window - value_range: value range of the inputs (usually 1. or 255) + x: An input tensor, (N, C, H, W). + y: A target tensor, (N, C, H, W). + window: A convolution window. + value_range: The value range of the inputs (usually 1. or 255). 
""" n_channels, _, window_size, _ = window.size() @@ -77,9 +70,12 @@ def ssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, val mu_y_sq = mu_y ** 2 mu_xy = mu_x * mu_y - sigma_x_sq = F.conv2d(x ** 2, window, padding=0, groups=n_channels) - mu_x_sq - sigma_y_sq = F.conv2d(y ** 2, window, padding=0, groups=n_channels) - mu_y_sq - sigma_xy = F.conv2d(x * y, window, padding=0, groups=n_channels) - mu_xy + sigma_x_sq = F.conv2d(x ** 2, window, padding=0, groups=n_channels) + sigma_x_sq -= mu_x_sq + sigma_y_sq = F.conv2d(y ** 2, window, padding=0, groups=n_channels) + sigma_y_sq -= mu_y_sq + sigma_xy = F.conv2d(x * y, window, padding=0, groups=n_channels) + sigma_xy -= mu_xy c1 = (_K1 * value_range) ** 2 c2 = (_K2 * value_range) ** 2 @@ -90,14 +86,19 @@ def ssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, val return ssim_map.mean((-1, -2)), cs_map.mean((-1, -2)) -def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, value_range: float = 1.) -> torch.Tensor: +def ssim( + x: torch.Tensor, + y: torch.Tensor, + window_size: int = 11, + value_range: float = 1., +) -> torch.Tensor: r"""Returns the SSIM between `x` and `y`. Args: - x: input tensor, (N, C, H, W) - y: target tensor, (N, C, H, W) - window_size: size of the window - value_range: value range of the inputs (usually 1. or 255) + x: An input tensor, (N, C, H, W). + y: A target tensor, (N, C, H, W). + window_size: The size of the window. + value_range: The value range of the inputs (usually 1. or 255). 
""" n_channels = x.size(1) @@ -106,15 +107,21 @@ def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, value_range: f return ssim_per_channel(x, y, window, value_range)[0].mean(-1) -def msssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, value_range: float = 1., weights: torch.Tensor = _WEIGHTS) -> torch.Tensor: +def msssim_per_channel( + x: torch.Tensor, + y: torch.Tensor, + window: torch.Tensor, + value_range: float = 1., + weights: torch.Tensor = _WEIGHTS, +) -> torch.Tensor: """Returns the MS-SSIM per channel between `x` and `y`. Args: - x: input tensor, (N, C, H, W) - y: target tensor, (N, C, H, W) - window: convolution window - value_range: value range of the inputs (usually 1. or 255) - weights: weights of the scales, (M,) + x: An input tensor, (N, C, H, W). + y: A target tensor, (N, C, H, W). + window: A convolution window. + value_range: The value range of the inputs (usually 1. or 255). + weights: The weights of the scales, (M,). """ mcs = [] @@ -134,15 +141,21 @@ def msssim_per_channel(x: torch.Tensor, y: torch.Tensor, window: torch.Tensor, v return msssim.prod(dim=0) -def msssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, value_range: float = 1., weights: torch.Tensor = _WEIGHTS) -> torch.Tensor: +def msssim( + x: torch.Tensor, + y: torch.Tensor, + window_size: int = 11, + value_range: float = 1., + weights: torch.Tensor = _WEIGHTS, +) -> torch.Tensor: r"""Returns the MS-SSIM between `x` and `y`. Args: - x: input tensor, (N, C, H, W) - y: target tensor, (N, C, H, W) - window_size: size of the window - value_range: value range of the inputs (usually 1. or 255) - weights: weights of the scales, (M,) + x: An input tensor, (N, C, H, W). + y: A target tensor, (N, C, H, W). + window_size: The size of the window. + value_range: The value range of the inputs (usually 1. or 255). + weights: The weights of the scales, (M,). 
 """ n_channels = x.size(1) @@ -152,39 +165,44 @@ def msssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, value_range: return msssim_per_channel(x, y, window, value_range, weights).mean(-1) -########### -# Classes # -########### - class SSIM(nn.Module): - r"""Creates a criterion that measures the SSIM between an input and a target. + r"""Creates a criterion that measures the SSIM + between an input and a target. Args: - window_size: size of the window - n_channels: number of channels - value_range: value range of the inputs (usually 1. or 255) - reduction: reduction type (`'mean'`, `'sum'` or `'none'`) + window_size: The size of the window. + n_channels: The number of channels. + value_range: The value range of the inputs (usually 1. or 255). + reduction: A reduction type (`'mean'`, `'sum'` or `'none'`). Call: The input and target tensors should be of shape (N, C, H, W). """ - def __init__(self, window_size: int = 11, n_channels: int = 3, value_range: float = 1., reduction: str = 'mean'): + def __init__( + self, + window_size: int = 11, + n_channels: int = 3, + value_range: float = 1., + reduction: str = 'mean', + ): super().__init__() - self.register_buffer( - 'window', - create_window(window_size, n_channels) - ) + self.register_buffer('window', create_window(window_size, n_channels)) self.value_range = value_range self.reduction = reduction - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: l = ssim_per_channel( - input, target, + input, + target, window=self.window, - value_range=self.value_range + value_range=self.value_range, )[0].mean(-1) if self.reduction == 'mean': @@ -196,30 +214,34 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class MSSSIM(SSIM): - r"""Creates a criterion that measures the MS-SSIM between an input and a target. 
+ r"""Creates a criterion that measures the MS-SSIM + between an input and a target. Args: - window_size: size of the window - n_channels: number of channels - value_range: value range of the inputs (usually 1. or 255) - weights: weights of the scales, (M,) - reduction: reduction type (`'mean'`, `'sum'` or `'none'`) + weights: The weights of the scales, (M,). + + All other arguments are inherited (see `SSIM`). Call: The input and target tensors should be of shape (N, C, H, W). """ - def __init__(self, window_size: int = 11, n_channels: int = 3, value_range: float = 1., weights: torch.Tensor = _WEIGHTS, reduction: str = 'mean'): - super().__init__(window_size, n_channels, value_range, reduction) + def __init__(self, weights: torch.Tensor = _WEIGHTS, *args, **kwargs): + super().__init__(*args, **kwargs) self.register_buffer('weights', weights) - def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> torch.Tensor: l = msssim_per_channel( - input, target, + input, + target, window=self.window, value_range=self.value_range, - weights=self.weights + weights=self.weights, ).mean(-1) if self.reduction == 'mean': diff --git a/spiq/tv.py b/spiq/tv.py index b51bcd3..e318f94 100644 --- a/spiq/tv.py +++ b/spiq/tv.py @@ -6,24 +6,16 @@ https://en.wikipedia.org/wiki/Total_variation """ -########### -# Imports # -########### - import torch import torch.nn as nn -############# -# Functions # -############# - def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor: r"""Returns the TV of `x`. Args: - x: input tensor, (..., C, H, W) - norm: norm function name (`'L1'`, `'L2'` or `'L2_squared'`) + x: An input tensor, (..., C, H, W). + norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`). 
""" w_var = x[..., :, 1:] - x[..., :, :-1] @@ -32,7 +24,7 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor: if norm in ['L2', 'L2_squared']: w_var = w_var ** 2 h_var = h_var ** 2 - else: # norm == 'L1' + else: # norm == 'L1' w_var = w_var.abs() h_var = h_var.abs() @@ -44,16 +36,12 @@ def tv(x: torch.Tensor, norm: str = 'L2') -> torch.Tensor: return score -########### -# Classes # -########### - class TV(nn.Module): r"""Creates a criterion that measures the TV of an input. Args: - norm: norm function name (`'L1'`, `'L2'` or `'L2_squared'`) - reduction: reduction type (`'mean'`, `'sum'` or `'none'`) + norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`). + reduction: A reduction type (`'mean'`, `'sum'` or `'none'`). Call: The input tensor should be of shape (N, C, H, W). diff --git a/spiq/utils.py b/spiq/utils.py index 4ee1654..6eedebd 100644 --- a/spiq/utils.py +++ b/spiq/utils.py @@ -1,29 +1,26 @@ r"""Miscellaneous tools such as modules, functionals and more. """ -########### -# Imports # -########### - import torch import torch.nn as nn from typing import List, Tuple -############# -# Functions # -############# - -def gaussian_kernel(kernel_size: int, sigma: float = 1., n: int = 2) -> torch.Tensor: +def gaussian_kernel( + kernel_size: int, + sigma: float = 1., + n: int = 2, +) -> torch.Tensor: r"""Returns the `n`-dimensional Gaussian kernel of size `kernel_size`. - The distribution is centered around the kernel's center and the standard deviation is `sigma`. + The distribution is centered around the kernel's center + and the standard deviation is `sigma`. Args: - kernel_size: size of the kernel - sigma: standard deviation of the distribution - n: number of dimensions of the kernel + kernel_size: The size of the kernel. + sigma: The standard deviation of the distribution. + n: The number of dimensions of the kernel. 
Wikipedia: https://en.wikipedia.org/wiki/Normal_distribution @@ -46,14 +43,19 @@ def gaussian_kernel(kernel_size: int, sigma: float = 1., n: int = 2) -> torch.Te return kernel -def tensor_norm(x: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = False, norm: str = 'L2') -> torch.Tensor: +def tensor_norm( + x: torch.Tensor, + dim: Tuple[int, ...] = (), + keepdim: bool = False, + norm: str = 'L2', +) -> torch.Tensor: r"""Returns the norm of `x`. Args: - x: input tensor - dim: dimension(s) along which to calculate the norm - keepdim: whether the output tensor has `dim` retained or not - norm: norm function name (`'L1'`, `'L2'` or `'L2_squared'`) + x: An input tensor. + dim: The dimension(s) along which to calculate the norm. + keepdim: Whether the output tensor has `dim` retained or not. + norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`). Wikipedia: https://en.wikipedia.org/wiki/Norm_(mathematics) @@ -61,7 +63,7 @@ def tensor_norm(x: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = Fals if norm in ['L2', 'L2_squared']: x = x ** 2 - else: # norm == 'L1' + else: # norm == 'L1' x = x.abs() x = x.sum(dim=dim, keepdim=keepdim) @@ -72,14 +74,19 @@ def tensor_norm(x: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = Fals return x -def normalize_tensor(x: torch.Tensor, dim: Tuple[int, ...] = (), norm: str = 'L2', epsilon: float = 1e-8) -> torch.Tensor: +def normalize_tensor( + x: torch.Tensor, + dim: Tuple[int, ...] = (), + norm: str = 'L2', + epsilon: float = 1e-8, +) -> torch.Tensor: r"""Returns `x` normalized. Args: - x: input tensor - dim: dimension(s) along which to normalize - norm: norm function name (`'L1'`, `'L2'` or `'L2_squared'`) - epsilon: numerical stability term + x: An input tensor. + dim: The dimension(s) along which to normalize. + norm: A norm function name (`'L1'`, `'L2'` or `'L2_squared'`). + epsilon: A numerical stability term. 
 """ norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm) @@ -87,16 +94,13 @@ def normalize_tensor(x: torch.Tensor, dim: Tuple[int, ...] = (), norm: str = 'L2 return x / (norm + epsilon) -########### -# Classes # -########### - class Intermediary(nn.Module): - r"""Module that catches and returns the outputs of indermediate target layers of a sequential module during its forward pass. + r"""Module that catches and returns the outputs of intermediate + target layers of a sequential module during its forward pass. Args: - layers: sequential module - targets: target layer indexes + layers: A sequential module. + targets: A list of target layer indexes. """ def __init__(self, layers: nn.Sequential, targets: List[int]):