Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename peak_signal_noise_ratio #732

Merged
merged 28 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6e95772
psnr --> peak_signal_noise_ratio
AIMistakes Jan 8, 2022
fcf2f7c
Revert "psnr --> peak_signal_noise_ratio"
Borda Jan 8, 2022
366f62d
psnr --> peak_signal_noise_ratio
Borda Jan 8, 2022
9bcbdf4
Refactor: SDR & SI_SDR (#711)
Borda Jan 8, 2022
15b2a42
changelog
SkafteNicki Jan 8, 2022
fadefe4
PSNR --> PeakSignalNoiseRatio
AIMistakes Jan 8, 2022
a407c6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2022
48d51c2
Refactor: SNR & SI_SNR (#712)
Borda Jan 8, 2022
2e28fcc
Fix metrics documentation (#728)
AndresAlgaba Jan 8, 2022
8c7c062
fbeta --> fbeta_score
AIMistakes Jan 9, 2022
1383db6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2022
8f0a15e
Revert "fbeta --> fbeta_score FBeta --> FBetaScore"
Borda Jan 10, 2022
726f9bd
rename f1 score (#731)
cuent Jan 9, 2022
b717fb8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2022
9390fc6
PSNR backward compatibility.
AIMistakes Jan 10, 2022
fa27cec
Merge branch 'master' into Fariborzzz/master
Borda Jan 11, 2022
14470ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2022
d6f4bfc
...
Borda Jan 11, 2022
18a7af6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2022
9f443e5
typo
Borda Jan 12, 2022
bd48431
Merge branch 'master' into master
Borda Jan 12, 2022
29070fe
fix
Borda Jan 12, 2022
6dc40cc
Merge branch 'master' of https://github.com/Fariborzzz/metrics into F…
Borda Jan 12, 2022
327d563
fix
Borda Jan 12, 2022
1c0fbbb
Merge branch 'master' into master
Borda Jan 12, 2022
f53c3b5
Merge branch 'master' into master
SkafteNicki Jan 12, 2022
e0a411e
docs
Borda Jan 12, 2022
ff67ba8
Merge branch 'master' into master
Borda Jan 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `torchmetrics.Hinge` -> `torchmetrics.HingeLoss`


- Renamed image metrics ([#732](https://github.com/PyTorchLightning/metrics/pull/732))
* `functional.psnr` -> `functional.peak_signal_noise_ratio`
* `PSNR` -> `PeakSignalNoiseRatio`


### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ the following limitations:
* Some metrics does not work at all in half precision on CPU. We have explicitly stated this in their docstring,
but they are also listed below:

- :ref:`references/modules:PSNR` and :ref:`references/functional:psnr [func]`
- :ref:`references/modules:PSNR` and :ref:`references/functional:peak_signal_noise_ratio [func]`
- :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]`
- :ref:`references/modules:KLDivergence` and :ref:`references/functional:kl_divergence [func]`

Expand Down
7 changes: 3 additions & 4 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,17 @@ image_gradients [func]
.. autofunction:: torchmetrics.functional.image_gradients
:noindex:


multiscale_structural_similarity_index_measure [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.multiscale_structural_similarity_index_measure
:noindex:


psnr [func]
~~~~~~~~~~~
peak_signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.psnr
.. autofunction:: torchmetrics.functional.peak_signal_noise_ratio
:noindex:


Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@ MultiScaleStructuralSimilarityIndexMeasure
.. autoclass:: torchmetrics.MultiScaleStructuralSimilarityIndexMeasure
:noindex:

PSNR
~~~~
PeakSignalNoiseRatio
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.PSNR
.. autoclass:: torchmetrics.PeakSignalNoiseRatio
:noindex:

SSIM
Expand Down
26 changes: 17 additions & 9 deletions tests/image/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
import numpy as np
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import peak_signal_noise_ratio as skimage_peak_signal_noise_ratio

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import psnr
from torchmetrics.image import PSNR
from torchmetrics.functional import peak_signal_noise_ratio, psnr
from torchmetrics.image import PSNR, PeakSignalNoiseRatio

seed_all(42)

Expand Down Expand Up @@ -64,7 +64,7 @@ def _sk_psnr(preds, target, data_range, reduction, dim):
np_reduce_map = {"elementwise_mean": np.mean, "none": np.array, "sum": np.sum}
return np_reduce_map[reduction](
[
peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
skimage_peak_signal_noise_ratio(sk_target, sk_preds, data_range=data_range)
for sk_target, sk_preds in zip(sk_target_lists, sk_preds_lists)
]
)
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc
self.run_functional_metric_test(
preds,
target,
psnr,
peak_signal_noise_ratio,
partial(sk_metric, data_range=data_range, reduction=reduction, dim=dim),
metric_args=_args,
)
Expand All @@ -121,13 +121,21 @@ def test_psnr_functional(self, preds, target, sk_metric, data_range, base, reduc
@pytest.mark.xfail(reason="PSNR metric does not support cpu + half precision")
def test_psnr_half_cpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_cpu(
preds, target, PSNR, psnr, {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
preds,
target,
PSNR,
peak_signal_noise_ratio,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim},
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_psnr_half_gpu(self, preds, target, data_range, reduction, dim, base, sk_metric):
self.run_precision_test_gpu(
preds, target, PSNR, psnr, {"data_range": data_range, "base": base, "reduction": reduction, "dim": dim}
preds,
target,
PSNR,
peak_signal_noise_ratio,
{"data_range": data_range, "base": base, "reduction": reduction, "dim": dim},
)


Expand All @@ -138,12 +146,12 @@ def test_reduction_for_dim_none(reduction):
PSNR(reduction=reduction, dim=None)

with pytest.warns(UserWarning, match=match):
psnr(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)


def test_missing_data_range():
with pytest.raises(ValueError):
PSNR(data_range=None, dim=0)

with pytest.raises(ValueError):
psnr(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
8 changes: 7 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@
Specificity,
StatScores,
)
from torchmetrics.image import PSNR, SSIM, MultiScaleStructuralSimilarityIndexMeasure # noqa: E402
from torchmetrics.image import ( # noqa: E402
PSNR,
SSIM,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
)
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.metric_collections import MetricCollection # noqa: E402
from torchmetrics.regression import ( # noqa: E402
Expand Down Expand Up @@ -146,6 +151,7 @@
"Precision",
"PrecisionRecallCurve",
"PSNR",
"PeakSignalNoiseRatio",
"R2Score",
"Recall",
"RetrievalFallOut",
Expand Down
3 changes: 2 additions & 1 deletion torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure
from torchmetrics.functional.image.psnr import psnr
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio, psnr
from torchmetrics.functional.image.ssim import ssim
from torchmetrics.functional.pairwise.cosine import pairwise_cosine_similarity
from torchmetrics.functional.pairwise.euclidean import pairwise_euclidean_distance
Expand Down Expand Up @@ -120,6 +120,7 @@
"precision",
"precision_recall",
"precision_recall_curve",
"peak_signal_noise_ratio",
"psnr",
"r2_score",
"recall",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.

from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import psnr # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio, psnr # noqa: F401
66 changes: 66 additions & 0 deletions torchmetrics/functional/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor, tensor
Expand Down Expand Up @@ -91,6 +92,65 @@ def _psnr_update(
return sum_squared_error, n_obs


def peak_signal_noise_ratio(
AIMistakes marked this conversation as resolved.
Show resolved Hide resolved
preds: Tensor,
target: Tensor,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tensor:
"""Computes the peak signal-to-noise ratio.

Args:
preds: estimated signal
target: groun truth signal
data_range:
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
when ``dim`` is not None.
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

dim:
Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
None meaning scores will be reduced across all dimensions.
Return:
Tensor with PSNR score

Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not provided.

Example:
>>> from torchmetrics.functional import peak_signal_noise_ratio
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> peak_signal_noise_ratio(pred, target)
tensor(2.5527)

.. note::
Half precision is only support on GPU for this metric
"""
if dim is None and reduction != "elementwise_mean":
rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")

if data_range is None:
if dim is not None:
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
# `data_range` in the future.
raise ValueError("The `data_range` must be given when `dim` is not None.")

data_range = target.max() - target.min()
else:
data_range = tensor(float(data_range))
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)


def psnr(
preds: Tensor,
target: Tensor,
Expand All @@ -101,6 +161,8 @@ def psnr(
) -> Tensor:
"""Computes the peak signal-to-noise ratio.

. deprecated:: v0.7
Use :function:torchmetrics.functional.psnr. Will be removed in v0.8.
Args:
preds: estimated signal
target: groun truth signal
Expand Down Expand Up @@ -147,4 +209,8 @@ def psnr(
else:
data_range = tensor(float(data_range))
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
warn(
"`psnr` was renamed to `peak_signal_noise_ratio` in v0.7 and it will be removed in v0.8",
DeprecationWarning,
)
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)
2 changes: 1 addition & 1 deletion torchmetrics/functional/regression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.ms_ssim import multiscale_structural_similarity_index_measure # noqa: F401
from torchmetrics.functional.image.psnr import psnr # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio, psnr # noqa: F401
from torchmetrics.functional.image.ssim import ssim # noqa: F401
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.psnr import PSNR # noqa: F401
from torchmetrics.image.psnr import PSNR, PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.ssim import SSIM, MultiScaleStructuralSimilarityIndexMeasure # noqa: F401
82 changes: 78 additions & 4 deletions torchmetrics/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Tuple, Union
from warnings import warn

import torch
from torch import Tensor, tensor
Expand All @@ -21,7 +22,7 @@
from torchmetrics.utilities import rank_zero_warn


class PSNR(Metric):
class PeakSignalNoiseRatio(Metric):
AIMistakes marked this conversation as resolved.
Show resolved Hide resolved
r"""
Computes `Computes Peak Signal-to-Noise Ratio`_ (PSNR):

Expand Down Expand Up @@ -56,11 +57,11 @@ class PSNR(Metric):
If ``dim`` is not ``None`` and ``data_range`` is not given.

Example:
>>> from torchmetrics import PSNR
>>> psnr = PSNR()
>>> from torchmetrics import PeakSignalNoiseRatio
>>> peak_signal_noise_ratio = PSNR()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
>>> peak_signal_noise_ratio(preds, target)
tensor(2.5527)

.. note::
Expand Down Expand Up @@ -146,3 +147,76 @@ def compute(self) -> Tensor:
sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
total = torch.cat([values.flatten() for values in self.total])
return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)


class PSNR(PeakSignalNoiseRatio):
"""Peak Signal Noise Ratio (PSNR).

.. deprecated:: v0.7
Use :class:`torchmetrics.PeakSignalNoiseRatio`. Will be removed in v0.8.

.. math:: \text{PSNR}(I, J) = 10 * \\log_{10} \\left(\frac{\\max(I)^2}{\text{MSE}(I, J)}\right)

Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function.

Args:
data_range:
the range of the data. If None, it is determined from the data (max - min).
The ``data_range`` must be given when ``dim`` is not None.
base: a base of a logarithm to use.
reduction: a method to reduce metric score over labels.

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

dim:
Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
None meaning scores will be reduced across all dimensions and all batches.
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.

Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not given.

Example:
>>> from torchmetrics import PSNR
>>> peak_signal_noise_ratio = PSNR()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
tensor(2.5527)

.. note::
Half precision is only support on GPU for this metric
"""

min_target: Tensor
max_target: Tensor
higher_is_better = False

def __init__(
self,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
) -> None:
warn(
"`PSNR` was renamed to `PeakSignalNoiseRatio` in v0.7 and it will be removed in v0.8",
DeprecationWarning,
)
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)