Skip to content

Commit

Permalink
✨ Peak Signal-to-Noise ratio (PSNR)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Oct 23, 2020
1 parent dfd66b7 commit d79e5f3
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 1 deletion.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,12 @@ cp -R spiq <path/to/project>/spiq

## Getting started

TODO
```python
import torch
import spiq

x = torch.rand(3, 3, 256, 256)
y = torch.rand(3, 3, 256, 256)

l = spiq.psnr(x, y)
```
3 changes: 3 additions & 0 deletions spiq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__version__ = '0.0.1'

from .psnr import psnr, PSNR
68 changes: 68 additions & 0 deletions spiq/psnr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
r"""Peak Signal-to-Noise Ratio (PSNR)
This module implements the PSNR in PyTorch.
Wikipedia:
https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
"""

###########
# Imports #
###########

import torch
import torch.nn as nn


#############
# Functions #
#############

def psnr(x: torch.Tensor, y: torch.Tensor, dim: tuple=(), value_range: float=1., epsilon: float=1e-8) -> torch.Tensor:
r"""Returns the PSNR between `x` and `y`.
Args:
x: input tensor
y: target tensor
dim: dimension(s) to reduce
value_range: value range of the inputs (usually 1. or 255)
epsilon: numerical stability
"""

mse = ((x - y) ** 2).mean(dim=dim) + epsilon
return 10 * torch.log10(value_range ** 2 / mse)


###########
# Classes #
###########

class PSNR(nn.Module):
r"""Creates a criterion that measures the PSNR between an input and a target.
"""

def __init__(self, value_range: float=1., reduction='mean'):
super().__init__()

self.value_range = value_range
self.reduction = reduction

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Args:
input: input tensor, (N, ...)
target: target tensor, (N, ...)
"""

l = psnr(
input, target,
dim=tuple(range(1, input.ndimension())),
value_range=self.value_range
)

if self.reduction == 'mean':
return l.mean()
elif self.reduction == 'sum':
return l.sum()

return l

0 comments on commit d79e5f3

Please sign in to comment.