Skip to content

Commit

Permalink
📝 Adopting Google Python style
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 4, 2020
1 parent fd8705f commit 7a83262
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 163 deletions.
10 changes: 10 additions & 0 deletions .style.yapf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[style]
based_on_style = google

spaces_around_power_operator = true

dedent_closing_brackets = true
coalesce_brackets = true
split_arguments_when_comma_terminated = true
split_all_top_level_comma_separated_values = true
split_complex_comprehension = true
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ The [documentation](https://francois-rozet.github.io/spiq/) of this package is g
```bash
pdoc spiq --html --config "git_link_template='https://github.com/francois-rozet/spiq/blob/{commit}/{path}#L{start_line}-L{end_line}'"
```

> The code follows the [Google Python style](https://google.github.io/styleguide/pyguide.html) and is compliant with [YAPF](https://github.com/google/yapf).
4 changes: 3 additions & 1 deletion spiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
r"""Simple PyTorch Image Quality
The spiq package is divided in several submodules, each of which implements the functions and/or classes related to a specific image quality metric.
The spiq package is divided in several submodules, each of
which implements the functions and/or classes related to a
specific image quality metric.
"""

__version__ = '0.0.2'
38 changes: 18 additions & 20 deletions spiq/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
https://arxiv.org/abs/1801.03924
"""

###########
# Imports #
###########

import inspect
import os
import torch
Expand All @@ -23,32 +19,30 @@

from spiq.utils import normalize_tensor, Intermediary


#############
# Constants #
#############

_SHIFT = torch.Tensor([0.485, 0.456, 0.406])
_SCALE = torch.Tensor([0.229, 0.224, 0.225])


###########
# Classes #
###########

class LPIPS(nn.Module):
r"""Creates a criterion that measures the LPIPS between an input and a target.
r"""Creates a criterion that measures the LPIPS
between an input and a target.
Args:
network: perception network name (`'AlexNet'`, `'SqueezeNet'` or `'VGG16'`)
scaling: whether the input and target are sclaed w.r.t. ImageNet
reduction: reduction type (`'mean'`, `'sum'` or `'none'`)
network: A perception network name (`'AlexNet'`,
`'SqueezeNet'` or `'VGG16'`).
scaling: Whether the input and target are scaled w.r.t. ImageNet.
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
Call:
The input and target tensors should be of shape (N, C, H, W).
"""

def __init__(self, network: str = 'AlexNet', scaling: bool = False, reduction: str = 'mean'):
def __init__(
self,
network: str = 'AlexNet',
scaling: bool = False,
reduction: str = 'mean',
):
super().__init__()

# ImageNet scaling
Expand Down Expand Up @@ -93,14 +87,18 @@ def __init__(self, network: str = 'AlexNet', scaling: bool = False, reduction: s

self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
if self.scaling:
input = (input - self.shift) / self.scale
target = (target - self.shift) / self.scale

residuals = []

for loss, (fx, fy) in zip(self.lin, zip(self.net(input), self.net(target))):
for loss, fx, fy in zip(self.lin, self.net(input), self.net(target)):
fx = normalize_tensor(fx, dim=1, norm='L2')
fy = normalize_tensor(fy, dim=1, norm='L2')

Expand Down
51 changes: 26 additions & 25 deletions spiq/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,42 @@
https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
"""

###########
# Imports #
###########

import torch
import torch.nn as nn

from typing import Tuple


#############
# Functions #
#############

def psnr(x: torch.Tensor, y: torch.Tensor, dim: Tuple[int, ...] = (), keepdim: bool = False, value_range: float = 1., epsilon: float = 1e-8) -> torch.Tensor:
def psnr(
x: torch.Tensor,
y: torch.Tensor,
dim: Tuple[int, ...] = (),
keepdim: bool = False,
value_range: float = 1.,
epsilon: float = 1e-8,
) -> torch.Tensor:
r"""Returns the PSNR between `x` and `y`.
Args:
x: input tensor
y: target tensor
dim: dimension(s) along which to average
keepdim: whether the output tensor has `dim` retained or not
value_range: value range of the inputs (usually 1. or 255)
epsilon: numerical stability term
x: An input tensor.
y: A target tensor.
dim: The dimension(s) along which to average.
keepdim: Whether the output tensor has `dim` retained or not.
value_range: The value range of the inputs (usually 1. or 255).
epsilon: A numerical stability term.
"""

mse = ((x - y) ** 2).mean(dim=dim, keepdim=keepdim) + epsilon
return 10 * torch.log10(value_range ** 2 / mse)


###########
# Classes #
###########

class PSNR(nn.Module):
r"""Creates a criterion that measures the PSNR between an input and a target.
r"""Creates a criterion that measures the PSNR
between an input and a target.
Args:
value_range: value range of the inputs (usually 1. or 255)
reduction: reduction type (`'mean'`, `'sum'` or `'none'`)
value_range: The value range of the inputs (usually 1. or 255).
reduction: A reduction type (`'mean'`, `'sum'` or `'none'`).
Call:
The input and target tensors should be of shape (N, ...).
Expand All @@ -57,11 +53,16 @@ def __init__(self, value_range: float = 1., reduction: str = 'mean'):
self.value_range = value_range
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
l = psnr(
input, target,
input,
target,
dim=tuple(range(1, input.ndimension())),
value_range=self.value_range
value_range=self.value_range,
)

if self.reduction == 'mean':
Expand Down
Loading

0 comments on commit 7a83262

Please sign in to comment.