diff --git a/CHANGELOG.md b/CHANGELOG.md index e481173d7a1..aed5276f194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556)) -- Added `ignore_index` to to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676)) +- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676)) - Added support for multi references in `ROUGEScore` ([#680](https://github.com/PyTorchLightning/metrics/pull/680)) @@ -71,6 +71,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `SI_SDR` -> `ScaleInvariantSignalDistortionRatio` +- Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712)) + * `functional.snr` -> `functional.signal_distortion_ratio` + * `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio` + * `SNR` -> `SignalNoiseRatio` + * `SI_SNR` -> `ScaleInvariantSignalNoiseRatio` + + - Renamed image metrics ([#732](https://github.com/PyTorchLightning/metrics/pull/732)) * `functional.psnr` -> `functional.peak_signal_noise_ratio` * `PSNR` -> `PeakSignalNoiseRatio` diff --git a/README.md b/README.md index 83a11254556..bb0eb81bce4 100644 --- a/README.md +++ b/README.md @@ -267,8 +267,8 @@ We currently have implemented metrics within the following domains: - Audio ( [ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio), - [SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr), - [SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr) + [ScaleInvariantSignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalNoiseRatio), + [SignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SignalNoiseRatio) and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics) ) - Classification ( diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 96aae85d36c..766850d98bd 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -38,17 +38,17 @@ scale_invariant_signal_distortion_ratio [func] :noindex: -si_snr [func] -~~~~~~~~~~~~~ +scale_invariant_signal_noise_ratio [func] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: torchmetrics.functional.si_snr :noindex: -snr [func] -~~~~~~~~~~ +signal_noise_ratio [func] +~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.snr +.. autofunction:: torchmetrics.functional.signal_noise_ratio :noindex: diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 74dcdb4f46e..c684925c074 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -66,12 +66,17 @@ the metric will be computed over the ``time`` dimension. .. doctest:: >>> import torch + >>> from torchmetrics import SignalNoiseRatio + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> snr = SignalNoiseRatio() + >>> snr(preds, target) + tensor(16.1805) >>> from torchmetrics import SNR >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> snr = SNR() - >>> snr_val = snr(preds, target) - >>> snr_val + >>> snr = SignalNoiseRatio() + >>> snr(preds, target) tensor(16.1805) PESQ @@ -97,16 +102,16 @@ ScaleInvariantSignalDistortionRatio .. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio :noindex: -SI_SNR -~~~~~~ +ScaleInvariantSignalNoiseRatio +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.SI_SNR +.. autoclass:: torchmetrics.ScaleInvariantSignalNoiseRatio :noindex: -SNR -~~~ +SignalNoiseRatio +~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.SNR +.. autoclass:: torchmetrics.SignalNoiseRatio :noindex: STOI diff --git a/tests/audio/test_si_snr.py b/tests/audio/test_si_snr.py index 063790b3579..6a3ad256daa 100644 --- a/tests/audio/test_si_snr.py +++ b/tests/audio/test_si_snr.py @@ -21,8 +21,8 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import SI_SNR -from torchmetrics.functional import si_snr +from torchmetrics.audio import ScaleInvariantSignalNoiseRatio +from torchmetrics.functional import scale_invariant_signal_noise_ratio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step): ddp, preds, target, - SI_SNR, + ScaleInvariantSignalNoiseRatio, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, ) @@ -88,12 +88,17 @@ def test_si_snr_functional(self, preds, target, sk_metric): self.run_functional_metric_test( preds, target, - si_snr, + scale_invariant_signal_noise_ratio, sk_metric, ) def test_si_snr_differentiability(self, preds, target, sk_metric): - self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=ScaleInvariantSignalNoiseRatio, + metric_functional=scale_invariant_signal_noise_ratio, + ) @pytest.mark.skipif( not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6" @@ -103,10 +108,15 @@ def test_si_snr_half_cpu(self, preds, target, sk_metric): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_si_snr_half_gpu(self, preds, target, sk_metric): - self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr) + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=ScaleInvariantSignalNoiseRatio, + metric_functional=scale_invariant_signal_noise_ratio, + ) -def test_error_on_different_shape(metric_class=SI_SNR): +def test_error_on_different_shape(metric_class=ScaleInvariantSignalNoiseRatio): metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/tests/audio/test_snr.py b/tests/audio/test_snr.py index ad63f99c9b5..1ab29a63049 100644 --- a/tests/audio/test_snr.py +++ b/tests/audio/test_snr.py @@ -22,19 +22,19 @@ from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester -from torchmetrics.audio import SNR -from torchmetrics.functional import snr +from torchmetrics.audio import SignalNoiseRatio +from torchmetrics.functional import signal_noise_ratio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) -Time = 100 +TIME = 100 Input = namedtuple("Input", ["preds", "target"]) inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME), ) @@ -86,7 +86,7 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step): ddp, preds, target, - SNR, + SignalNoiseRatio, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(zero_mean=zero_mean), @@ -96,14 +96,18 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean): self.run_functional_metric_test( preds, target, - snr, + signal_noise_ratio, sk_metric, metric_args=dict(zero_mean=zero_mean), ) def test_snr_differentiability(self, preds, target, sk_metric, zero_mean): self.run_differentiability_test( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean} + preds=preds, + target=target, + metric_module=SignalNoiseRatio, + metric_functional=signal_noise_ratio, + metric_args={"zero_mean": zero_mean}, ) @pytest.mark.skipif( @@ -115,11 +119,15 @@ def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean): self.run_precision_test_gpu( - preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean} + preds=preds, + target=target, + metric_module=SignalNoiseRatio, + metric_functional=signal_noise_ratio, + metric_args={"zero_mean": zero_mean}, ) -def test_error_on_different_shape(metric_class=SNR): +def test_error_on_different_shape(metric_class=SignalNoiseRatio): metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index d53d3481a19..93d178aecd9 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -20,7 +20,9 @@ SI_SNR, SNR, ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, SignalDistortionRatio, + SignalNoiseRatio, ) from torchmetrics.classification import ( # noqa: E402, F401 AUC, @@ -154,6 +156,8 @@ "ScaleInvariantSignalDistortionRatio", "SI_SDR", "SI_SNR", + "ScaleInvariantSignalNoiseRatio", + "SignalNoiseRatio", "SNR", "SpearmanCorrcoef", "SpearmanCorrCoef", diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 092cb1bd60b..8dfefb1abc0 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -15,4 +15,4 @@ from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401 from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 -from torchmetrics.audio.snr import SNR # noqa: F401 +from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401 diff --git a/torchmetrics/audio/si_snr.py b/torchmetrics/audio/si_snr.py index 82a30bbabe2..a88ad21fc8c 100644 --- a/torchmetrics/audio/si_snr.py +++ b/torchmetrics/audio/si_snr.py @@ -12,61 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn -from torch import Tensor, tensor +from torch import Tensor -from torchmetrics.functional.audio.si_snr import si_snr -from torchmetrics.metric import Metric +from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio -class SI_SNR(Metric): +class SI_SNR(ScaleInvariantSignalNoiseRatio): """Scale-invariant signal-to-noise ratio (SI-SNR). - Forward accepts - - - ``preds``: ``shape [...,time]`` - - ``target``: ``shape [...,time]`` - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. - - Raises: - TypeError: - if target and preds have a different shape - - Returns: - average si-snr value + .. deprecated:: v0.7 + Use :class:`torchmetrics.ScaleInvariantSignalNoiseRatio`. Will be removed in v0.8. Example: >>> import torch - >>> from torchmetrics import SI_SNR - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) >>> si_snr = SI_SNR() - >>> si_snr_val = si_snr(preds, target) - >>> si_snr_val + >>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) tensor(15.0918) - - References: - [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech - Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. - 696-700, doi: 10.1109/ICASSP.2018.8462116. """ - is_differentiable = True - sum_si_snr: Tensor - total: Tensor - higher_is_better = True - def __init__( self, compute_on_step: bool = True, @@ -74,28 +39,8 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, ) -> None: - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, + warn( + "`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8", + DeprecationWarning, ) - - self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - si_snr_batch = si_snr(preds=preds, target=target) - - self.sum_si_snr += si_snr_batch.sum() - self.total += si_snr_batch.numel() - - def compute(self) -> Tensor: - """Computes average SI-SNR.""" - return self.sum_si_snr / self.total + super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) diff --git a/torchmetrics/audio/snr.py b/torchmetrics/audio/snr.py index e42078e7b69..06280ab728a 100644 --- a/torchmetrics/audio/snr.py +++ b/torchmetrics/audio/snr.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Callable, Optional +from warnings import warn from torch import Tensor, tensor -from torchmetrics.functional.audio.snr import snr +from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, snr from torchmetrics.metric import Metric -class SNR(Metric): +class SignalNoiseRatio(Metric): r"""Signal-to-noise ratio (SNR_): .. math:: @@ -57,12 +58,11 @@ class SNR(Metric): Example: >>> import torch - >>> from torchmetrics import SNR + >>> from torchmetrics import SignalNoiseRatio >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> snr = SNR() - >>> snr_val = snr(preds, target) - >>> snr_val + >>> snr = SignalNoiseRatio() + >>> snr(preds, target) tensor(16.1805) References: @@ -109,3 +109,110 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: """Computes average SNR.""" return self.sum_snr / self.total + + +class SNR(SignalNoiseRatio): + r"""Signal-to-noise ratio (SNR_): + + .. deprecated:: v0.7 + Use :class:`torchmetrics.SignalNoiseRatio`. Will be removed in v0.8. + + Example: + >>> import torch + >>> snr = SNR() + >>> snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) + tensor(16.1805) + + """ + + def __init__( + self, + zero_mean: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + warn("`SNR` was renamed to `SignalNoiseRatio` in v0.7 and it will be removed in v0.8", DeprecationWarning) + super().__init__(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) + + +class ScaleInvariantSignalNoiseRatio(Metric): + """Scale-invariant signal-to-noise ratio (SI-SNR). + + Forward accepts + + - ``preds``: ``shape [...,time]`` + - ``target``: ``shape [...,time]`` + + Args: + compute_on_step: + Forward only calls ``update()`` and returns None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. + + Raises: + TypeError: + if target and preds have a different shape + + Returns: + average si-snr value + + Example: + >>> import torch + >>> from torchmetrics import ScaleInvariantSignalNoiseRatio + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> si_snr = ScaleInvariantSignalNoiseRatio() + >>> si_snr(preds, target) + tensor(15.0918) + + References: + [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech + Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. + 696-700, doi: 10.1109/ICASSP.2018.8462116. + """ + + is_differentiable = True + sum_si_snr: Tensor + total: Tensor + higher_is_better = True + + def __init__( + self, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + si_snr_batch = scale_invariant_signal_noise_ratio(preds=preds, target=target) + + self.sum_si_snr += si_snr_batch.sum() + self.total += si_snr_batch.numel() + + def compute(self) -> Tensor: + """Computes average SI-SNR.""" + return self.sum_si_snr / self.total diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 3ba382bbd9c..006c77732a2 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -15,7 +15,7 @@ from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, sdr, signal_distortion_ratio from torchmetrics.functional.audio.si_sdr import si_sdr from torchmetrics.functional.audio.si_snr import si_snr -from torchmetrics.functional.audio.snr import snr +from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr from torchmetrics.functional.classification.accuracy import accuracy from torchmetrics.functional.classification.auc import auc from torchmetrics.functional.classification.auroc import auroc @@ -133,7 +133,9 @@ "si_sdr", "scale_invariant_signal_distortion_ratio", "si_snr", + "scale_invariant_signal_noise_ratio", "snr", + "signal_noise_ratio", "spearman_corrcoef", "specificity", "squad", diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py index db7f29ecbd3..3cd5c6b73bd 100644 --- a/torchmetrics/functional/audio/__init__.py +++ b/torchmetrics/functional/audio/__init__.py @@ -19,4 +19,4 @@ ) from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401 from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401 -from torchmetrics.functional.audio.snr import snr # noqa: F401 +from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio, snr # noqa: F401 diff --git a/torchmetrics/functional/audio/si_snr.py b/torchmetrics/functional/audio/si_snr.py index a967ddc306c..65acdd8f59c 100644 --- a/torchmetrics/functional/audio/si_snr.py +++ b/torchmetrics/functional/audio/si_snr.py @@ -11,36 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from warnings import warn + from torch import Tensor -from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio +from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio def si_snr(preds: Tensor, target: Tensor) -> Tensor: """Scale-invariant signal-to-noise ratio (SI-SNR). - Args: - preds: - shape ``[...,time]`` - target: - shape ``[...,time]`` - - Returns: - si-snr value of shape [...] + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.scale_invariant_signal_noise_ratio`. Will be removed in v0.8. Example: >>> import torch - >>> from torchmetrics.functional.audio import si_snr - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> si_snr_val = si_snr(preds, target) - >>> si_snr_val + >>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) tensor(15.0918) - - References: - [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech - Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. - 696-700, doi: 10.1109/ICASSP.2018.8462116. """ - - return scale_invariant_signal_distortion_ratio(target=target, preds=preds, zero_mean=True) + warn( + "`si_snr` was renamed to `scale_invariant_signal_noise_ratio` in v0.7 and it will be removed in v0.8", + DeprecationWarning, + ) + return scale_invariant_signal_noise_ratio(preds, target) diff --git a/torchmetrics/functional/audio/snr.py b/torchmetrics/functional/audio/snr.py index dc1fc42e9c1..8a199c07cbf 100644 --- a/torchmetrics/functional/audio/snr.py +++ b/torchmetrics/functional/audio/snr.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from warnings import warn + import torch from torch import Tensor +from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio from torchmetrics.utilities.checks import _check_same_shape -def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: +def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: r"""Signal-to-noise ratio (SNR_): .. math:: @@ -39,11 +42,10 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: snr value of shape [...] Example: - >>> from torchmetrics.functional.audio import snr + >>> from torchmetrics.functional.audio import signal_noise_ratio >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> snr_val = snr(preds, target) - >>> snr_val + >>> signal_noise_ratio(preds, target) tensor(16.1805) References: @@ -64,3 +66,46 @@ def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: snr_value = 10 * torch.log10(snr_value) return snr_value + + +def snr(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: + r"""Signal-to-noise ratio (SNR_) + + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.signal_noise_ratio`. Will be removed in v0.8. + + Example: + >>> snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) + tensor(16.1805) + + """ + warn("`snr` was renamed to `signal_noise_ratio` in v0.7 and it will be removed in v0.8", DeprecationWarning) + return signal_noise_ratio(preds, target, zero_mean) + + +def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor: + """Scale-invariant signal-to-noise ratio (SI-SNR). + + Args: + preds: + shape ``[...,time]`` + target: + shape ``[...,time]`` + + Returns: + si-snr value of shape [...] + + Example: + >>> import torch + >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> scale_invariant_signal_noise_ratio(preds, target) + tensor(15.0918) + + References: + [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech + Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. + 696-700, doi: 10.1109/ICASSP.2018.8462116. + """ + return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=True)