diff --git a/CHANGELOG.md b/CHANGELOG.md index e77fd63661783..209409299548b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added a PSNR metric: peak signal-to-noise ratio ([#2483](https://github.com/PyTorchLightning/pytorch-lightning/pull/2483)) ### Changed diff --git a/environment.yml b/environment.yml index 068d5bef11c08..25d1a0d6bdb51 100644 --- a/environment.yml +++ b/environment.yml @@ -29,6 +29,7 @@ dependencies: - autopep8 - twine==1.13.0 - pillow<7.0.0 + - scikit-image # Optional - scipy>=0.13.3 diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 34e8ece3e6e10..e1409387f8e71 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -5,6 +5,7 @@ RMSE, MAE, RMSLE, + PSNR ) from pytorch_lightning.metrics.classification import ( Accuracy, @@ -50,6 +51,7 @@ 'MSE', 'RMSE', 'MAE', - 'RMSLE' + 'RMSLE', + 'PSNR' ] __all__ = __regression_metrics + __classification_metrics + ['SklearnMetric'] diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py new file mode 100644 index 0000000000000..db2e964d2bc30 --- /dev/null +++ b/pytorch_lightning/metrics/functional/regression.py @@ -0,0 +1,40 @@ +import torch +from torch.nn import functional as F + +from pytorch_lightning.metrics.functional.reduction import reduce + + +def psnr( + pred: torch.Tensor, + target: torch.Tensor, + data_range: float = None, + base: float = 10.0, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes the peak signal-to-noise ratio metric + + Args: + pred: estimated signal + target: groun truth signal + data_range: the range of the data. If None, it is determined from the data (max - min). + base: a base of a logarithm to use (default: 10) + reduction: method for reducing psnr (default: takes the mean) + + Example: + + >>> from pytorch_lightning.metrics.regression import PSNR + >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> metric = PSNR() + >>> metric(pred, target) + tensor(2.5527) + """ + + if data_range is None: + data_range = max(target.max() - target.min(), pred.max() - pred.min()) + else: + data_range = torch.tensor(float(data_range)) + mse = F.mse_loss(pred.view(-1), target.view(-1), reduction=reduction) + psnr_base_e = 2 * torch.log(data_range) - torch.log(mse) + return psnr_base_e * (10 / torch.log(torch.tensor(base))) diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py index 964d88cd9604f..7d7db14223680 100644 --- a/pytorch_lightning/metrics/regression.py +++ b/pytorch_lightning/metrics/regression.py @@ -1,8 +1,11 @@ import torch.nn.functional as F import torch from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.functional.regression import ( + psnr, +) -__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE'] +__all__ = ['MSE', 'RMSE', 'MAE', 'RMSLE', 'PSNR'] class MSE(Metric): @@ -187,3 +190,33 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: A Tensor with the rmsle loss. """ return F.mse_loss(torch.log(pred + 1), torch.log(target + 1), self.reduction) + + +class PSNR(Metric): + """ + Computes the peak signal-to-noise ratio metric + """ + + def __init__(self, data_range: float = None, base: int = 10, reduction: str = 'elementwise_mean'): + """ + Args: + data_range: the range of the data. If None, it is determined from the data (max - min). + base: a base of a logarithm to use (default: 10) + reduction: method for reducing psnr (default: takes the mean) + + + Example: + + >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> metric = PSNR() + >>> metric(pred, target) + tensor(2.5527) + """ + super().__init__(name='psnr') + self.data_range = data_range + self.base = float(base) + self.reduction = reduction + + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return psnr(pred, target, self.data_range, self.base, self.reduction) diff --git a/requirements/base.txt b/requirements/base.txt index bacc868dada85..4072df9466dc7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,6 +1,6 @@ # the default package dependencies -numpy>=1.15 # because some BLAS compilation issues +numpy>=1.16.4 torch>=1.3 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py diff --git a/requirements/test.txt b/requirements/test.txt index f2f595edf718b..bd31596c21f69 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -7,7 +7,7 @@ flake8 flake8-black check-manifest twine==1.13.0 - +scikit-image black==19.10b0 pre-commit>=1.0 diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py new file mode 100644 index 0000000000000..824d4adde1f45 --- /dev/null +++ b/tests/metrics/functional/test_regression.py @@ -0,0 +1,37 @@ +import pytest +import torch + +from skimage.metrics import peak_signal_noise_ratio as ski_psnr +from pytorch_lightning.metrics.functional.regression import psnr + + +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio') +]) +def test_psnr_against_sklearn(sklearn_metric, torch_metric): + """Compare PL metrics to sklearn version.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + pred = torch.randint(10, (500,), device=device, dtype=torch.double) + target = torch.randint(10, (500,), device=device, dtype=torch.double) + assert torch.allclose( + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + data_range=10), dtype=torch.double, device=device), + torch_metric(pred, target, data_range=10)) + + pred = torch.randint(5, (500,), device=device, dtype=torch.double) + target = torch.randint(10, (500,), device=device, dtype=torch.double) + assert torch.allclose( + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + data_range=10), dtype=torch.double, device=device), + torch_metric(pred, target, data_range=10)) + + pred = torch.randint(10, (500,), device=device, dtype=torch.double) + target = torch.randint(5, (500,), device=device, dtype=torch.double) + assert torch.allclose( + torch.tensor(sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + data_range=5), dtype=torch.double, device=device), + torch_metric(pred, target, data_range=5)) diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py index 5675051cff090..702e66fcac373 100644 --- a/tests/metrics/test_regression.py +++ b/tests/metrics/test_regression.py @@ -1,8 +1,10 @@ import pytest import torch +from skimage.metrics import peak_signal_noise_ratio as ski_psnr +import numpy as np from pytorch_lightning.metrics.regression import ( - MAE, MSE, RMSE, RMSLE + MAE, MSE, RMSE, RMSLE, PSNR ) @@ -64,3 +66,45 @@ def test_rmsle(pred, target, exp): assert isinstance(score, torch.Tensor) assert pytest.approx(score.item(), rel=1e-3) == exp + + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param( + [0., 1., 2., 3.], + [0., 1., 2., 2.], + ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=3) + ), + pytest.param( + [4., 3., 2., 1.], + [1., 4., 3., 2.], + ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=3) + ) +]) +def test_psnr(pred, target, exp): + psnr = PSNR() + assert psnr.name == 'psnr' + score = psnr(pred=torch.tensor(pred), + target=torch.tensor(target)) + assert isinstance(score, torch.Tensor) + assert pytest.approx(score.item(), rel=1e-3) == exp + + +@pytest.mark.parametrize(['pred', 'target', 'exp'], [ + pytest.param( + [0., 1., 2., 3.], + [0., 1., 2., 2.], + ski_psnr(np.array([0., 1., 2., 3.]), np.array([0., 1., 2., 2.]), data_range=4) * np.log(10) + ), + pytest.param( + [4., 3., 2., 1.], + [1., 4., 3., 2.], + ski_psnr(np.array([4., 3., 2., 1.]), np.array([1., 4., 3., 2.]), data_range=4) * np.log(10) + ) +]) +def test_psnr_base_e_wider_range(pred, target, exp): + psnr = PSNR(data_range=4, base=2.718281828459045) + assert psnr.name == 'psnr' + score = psnr(pred=torch.tensor(pred), + target=torch.tensor(target)) + assert isinstance(score, torch.Tensor) + assert pytest.approx(score.item(), rel=1e-3) == exp