From d79e5f32981350475c6446fcba9610fd631ecc07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Wed, 23 Sep 2020 21:05:11 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Peak=20Signal-to-Noise=20ratio=20(P?= =?UTF-8?q?SNR)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 ++++++- spiq/__init__.py | 3 +++ spiq/psnr.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 spiq/__init__.py create mode 100644 spiq/psnr.py diff --git a/README.md b/README.md index 70f1f0e..308c6e0 100644 --- a/README.md +++ b/README.md @@ -24,4 +24,12 @@ cp -R spiq /spiq ## Getting started -TODO +```python +import torch +import spiq + +x = torch.rand(3, 3, 256, 256) +y = torch.rand(3, 3, 256, 256) + +l = spiq.psnr(x, y) +``` diff --git a/spiq/__init__.py b/spiq/__init__.py new file mode 100644 index 0000000..f091b35 --- /dev/null +++ b/spiq/__init__.py @@ -0,0 +1,3 @@ +__version__ = '0.0.1' + +from .psnr import psnr, PSNR diff --git a/spiq/psnr.py b/spiq/psnr.py new file mode 100644 index 0000000..d0c5b0b --- /dev/null +++ b/spiq/psnr.py @@ -0,0 +1,68 @@ +r"""Peak Signal-to-Noise Ratio (PSNR) + +This module implements the PSNR in PyTorch. + +Wikipedia: + https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio +""" + +########### +# Imports # +########### + +import torch +import torch.nn as nn + + +############# +# Functions # +############# + +def psnr(x: torch.Tensor, y: torch.Tensor, dim: tuple=(), value_range: float=1., epsilon: float=1e-8) -> torch.Tensor: + r"""Returns the PSNR between `x` and `y`. + + Args: + x: input tensor + y: target tensor + dim: dimension(s) to reduce + value_range: value range of the inputs (usually 1. or 255) + epsilon: numerical stability + """ + + mse = ((x - y) ** 2).mean(dim=dim) + epsilon + return 10 * torch.log10(value_range ** 2 / mse) + + +########### +# Classes # +########### + +class PSNR(nn.Module): + r"""Creates a criterion that measures the PSNR between an input and a target. + """ + + def __init__(self, value_range: float=1., reduction='mean'): + super().__init__() + + self.value_range = value_range + self.reduction = reduction + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + r""" + Args: + input: input tensor, (N, ...) + target: target tensor, (N, ...) + """ + + l = psnr( + input, target, + dim=tuple(range(1, input.ndimension())), + value_range=self.value_range + ) + + if self.reduction == 'mean': + return l.mean() + elif self.reduction == 'sum': + return l.sum() + + return l