diff --git a/CHANGELOG.md b/CHANGELOG.md index 64dce444534..2bb38b37cf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `pairwise_manhatten_distance` +- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 4a9ddc853eb..2b4708ecea0 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -45,6 +45,13 @@ snr [func] :noindex: +stoi [func] +~~~~~~~~~~~ + +.. autofunction:: torchmetrics.functional.stoi + :noindex: + + ********************** Classification Metrics ********************** diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 847ebef3ca9..a6566018384 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -103,6 +103,12 @@ SNR .. autoclass:: torchmetrics.SNR :noindex: +STOI +~~~~ + +.. autoclass:: torchmetrics.STOI + :noindex: + ********************** Classification Metrics diff --git a/requirements/audio.txt b/requirements/audio.txt index 9dbb48e2084..f0d64dfb017 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1 +1,2 @@ pesq>=0.0.3 +pystoi diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py new file mode 100644 index 00000000000..9f98bc9b5ed --- /dev/null +++ b/tests/audio/test_stoi.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import namedtuple
+from functools import partial
+
+import pytest
+import torch
+from pystoi import stoi as stoi_backend
+from torch import Tensor
+
+from tests.helpers import seed_all
+from tests.helpers.testers import MetricTester
+from torchmetrics.audio import STOI
+from torchmetrics.functional import stoi
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
+
+seed_all(42)  # seed before the random fixtures below are built
+
+Input = namedtuple("Input", ["preds", "target"])
+
+inputs_8k = Input(  # presumably [NUM_BATCHES, BATCH_SIZE, Time] - TODO confirm against MetricTester
+    preds=torch.rand(2, 3, 8000),
+    target=torch.rand(2, 3, 8000),
+)
+inputs_16k = Input(  # same layout as inputs_8k, with 16000 samples per signal
+    preds=torch.rand(2, 3, 16000),
+    target=torch.rand(2, 3, 16000),
+)
+
+
+def stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool):
+    # Reference implementation: score every row with the pystoi backend and return the scores as a 1-D tensor.
+    # shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
+    # or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
+    target = target.detach().cpu().numpy()
+    preds = preds.detach().cpu().numpy()
+    mss = []  # one backend score per row of the first dimension
+    for b in range(preds.shape[0]):
+        pesq_val = stoi_backend(target[b, ...], preds[b, ...], fs, extended)  # a STOI score; name looks PESQ-era
+        mss.append(pesq_val)
+    return torch.tensor(mss)
+
+
+def average_metric(preds, target, metric_func):
+    # Mean of metric_func over its output; inputs are forwarded unchanged, so shapes match stoi_original_batch:
+    # preds/target [BATCH_SIZE, Time] or [NUM_BATCHES*BATCH_SIZE, Time]
+    return metric_func(preds, target).mean()
+
+
+stoi_original_batch_8k_ext = partial(stoi_original_batch, fs=8000, extended=True)
+stoi_original_batch_16k_ext = partial(stoi_original_batch, fs=16000, extended=True)
+stoi_original_batch_8k_noext = partial(stoi_original_batch, fs=8000, extended=False)
+stoi_original_batch_16k_noext = partial(stoi_original_batch, fs=16000, extended=False)
+
+
+@pytest.mark.parametrize(
+    "preds, target, sk_metric, fs, extended",
+    [
+        (inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_ext, 8000, True),
+        (inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_ext, 16000, True),
+        (inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_noext, 8000, False),
+        (inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_noext, 16000, False),
+    ],
+)
+class TestSTOI(MetricTester):
+    atol = 1e-2  # NOTE(review): presumably the tolerance MetricTester uses when comparing results - confirm
+
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step):
+        self.run_class_metric_test(
+            ddp,
+            preds,
+            target,
+            STOI,
+            sk_metric=partial(average_metric, metric_func=sk_metric),  # STOI.compute() is a mean, so average too
+            dist_sync_on_step=dist_sync_on_step,
+            metric_args=dict(fs=fs, extended=extended),
+        )
+
+    def test_stoi_functional(self, preds, target, sk_metric, fs, extended):
+        self.run_functional_metric_test(
+            preds,
+            target,
+            stoi,
+            sk_metric,  # un-averaged reference: the functional returns one score per signal
+            metric_args=dict(fs=fs, extended=extended),
+        )
+
+    def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended):
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=STOI,  # STOI declares is_differentiable = False
+            metric_functional=stoi,
+            metric_args=dict(fs=fs, extended=extended),
+        )
+
+    @pytest.mark.skipif(
+        not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
+    )
+    def test_stoi_half_cpu(self, preds, target, sk_metric, fs, extended):
+        pytest.xfail("STOI metric does not support cpu + half precision")
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended):
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=STOI,
+            metric_functional=partial(stoi, fs=fs, extended=extended),
+            metric_args=dict(fs=fs, extended=extended),
+        )
+
+
+def test_error_on_different_shape(metric_class=STOI):
+    metric = metric_class(16000)  # 16000 is fs, the first positional argument of STOI
+    with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
+        metric(torch.randn(100), torch.randn(50))
+
+
+def test_on_real_audio():
+    import os
+
+    from scipy.io import wavfile
+
+    current_file_dir = os.path.dirname(__file__)
+
+    rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
+    rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))  # overwrites rate
+    assert torch.allclose(
+        stoi(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(),  # preds=deg, target=ref
+        torch.tensor(0.6739177),  # expected score; provenance not shown here - presumably pystoi output, confirm
+        rtol=0.0001,
+        atol=1e-4,
+    )
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 45d52ca4d1d..3d65a6eb8bc 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -13,7 +13,7 @@
 
 from torchmetrics import functional  # noqa: E402
 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric  # noqa: E402
-from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR  # noqa: E402
+from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR, STOI  # noqa: E402
 from torchmetrics.classification import (  # noqa: E402
     AUC,
     AUROC,
@@ -131,6 +131,7 @@
     "Specificity",
     "SSIM",
     "StatScores",
+    "STOI",
     "SumMetric",
     "SymmetricMeanAbsolutePercentageError",
     "WER",
diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py
index f25cbc40bba..fe1dd7e4901 100644
--- a/torchmetrics/audio/__init__.py
+++ b/torchmetrics/audio/__init__.py
@@ -16,3 +16,4 @@
 from torchmetrics.audio.si_sdr import SI_SDR  # noqa: F401
 from torchmetrics.audio.si_snr import SI_SNR  # noqa: F401
 from torchmetrics.audio.snr import SNR  # noqa: F401
+from torchmetrics.audio.stoi import STOI  # noqa: F401
diff --git a/torchmetrics/audio/stoi.py
b/torchmetrics/audio/stoi.py
new file mode 100644
index 00000000000..1c2148b9b4c
--- /dev/null
+++ b/torchmetrics/audio/stoi.py
@@ -0,0 +1,133 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional
+
+from torch import Tensor, tensor
+
+from torchmetrics.functional.audio.stoi import stoi
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE
+
+
+class STOI(Metric):
+    r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
+    Note that input will be moved to `cpu` to perform the metric calculation.
+
+    Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due
+    to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
+    The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
+    alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
+    interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms,
+    on speech intelligibility. Description taken from `Cees Taal's website <http://www.ceestaal.nl/code/>`_.
+
+    .. note:: using this metric requires you to have ``pystoi`` installed. Either install as ``pip install
+        torchmetrics[audio]`` or ``pip install pystoi``
+
+    Forward accepts
+
+    - ``preds``: ``shape [...,time]``
+    - ``target``: ``shape [...,time]``
+
+    Args:
+        fs:
+            sampling frequency (Hz)
+        extended:
+            whether to use the extended STOI described in [4]
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step.
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather.
+
+    Returns:
+        average STOI value
+
+    Raises:
+        ModuleNotFoundError:
+            If ``pystoi`` package is not installed
+
+    Example:
+        >>> from torchmetrics.audio import STOI
+        >>> import torch
+        >>> g = torch.manual_seed(1)
+        >>> preds = torch.randn(8000)
+        >>> target = torch.randn(8000)
+        >>> stoi = STOI(8000, False)
+        >>> stoi(preds, target)
+        tensor(-0.0100)
+
+    References:
+        [1] https://github.com/mpariente/pystoi
+
+        [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time Objective Intelligibility Measure for
+        Time-Frequency Weighted Noisy Speech', ICASSP 2010, Texas, Dallas.
+
+        [3] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for Intelligibility Prediction of
+        Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+
+        [4] J. Jensen and C. H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated
+        Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016.
+
+    """
+    sum_stoi: Tensor
+    total: Tensor
+    is_differentiable = False
+    higher_is_better = True
+
+    def __init__(
+        self,
+        fs: int,
+        extended: bool = False,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
+    ) -> None:
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+        if not _PYSTOI_AVAILABLE:  # fail at construction time rather than on the first update
+            raise ModuleNotFoundError(
+                "STOI metric requires that pystoi is installed."
+                " Either install as `pip install torchmetrics[audio]` or `pip install pystoi`"
+            )
+        self.fs = fs
+        self.extended = extended
+
+        self.add_state("sum_stoi", default=tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
+        """Update state with predictions and targets.
+
+        Args:
+            preds: Predictions from model
+            target: Ground truth values
+        """
+        stoi_batch = stoi(preds, target, self.fs, self.extended, False).to(self.sum_stoi.device)
+
+        self.sum_stoi += stoi_batch.sum()  # running sum of the per-signal scores
+        self.total += stoi_batch.numel()  # running count of scored signals
+
+    def compute(self) -> Tensor:
+        """Computes average STOI."""
+        return self.sum_stoi / self.total
diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py
index a8c7aec49bc..a25500a5642 100644
--- a/torchmetrics/functional/__init__.py
+++ b/torchmetrics/functional/__init__.py
@@ -16,6 +16,7 @@
 from torchmetrics.functional.audio.si_sdr import si_sdr
 from torchmetrics.functional.audio.si_snr import si_snr
 from torchmetrics.functional.audio.snr import snr
+from torchmetrics.functional.audio.stoi import stoi
 from torchmetrics.functional.classification.accuracy import accuracy
 from torchmetrics.functional.classification.auc import auc
 from torchmetrics.functional.classification.auroc import auroc
@@ -127,6 +128,7 @@
     "specificity",
     "ssim",
     "stat_scores",
+    "stoi",
     "symmetric_mean_absolute_percentage_error",
     "wer",
 ]
diff --git a/torchmetrics/functional/audio/__init__.py b/torchmetrics/functional/audio/__init__.py
index f21a4421dbe..678f45419db 100644
--- a/torchmetrics/functional/audio/__init__.py
+++ b/torchmetrics/functional/audio/__init__.py
@@ -16,3 +16,4 @@
 from torchmetrics.functional.audio.si_sdr import si_sdr  # noqa: F401
 from torchmetrics.functional.audio.si_snr import si_snr  # noqa: F401
 from torchmetrics.functional.audio.snr import snr  # noqa: F401
+from torchmetrics.functional.audio.stoi import stoi  # noqa: F401
diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py
new file mode 100644
index 00000000000..71e36bf9c54
--- /dev/null
+++ b/torchmetrics/functional/audio/stoi.py
@@ -0,0 +1,105 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+
+from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE
+
+if _PYSTOI_AVAILABLE:
+    from pystoi import stoi as stoi_backend
+else:
+    stoi_backend = None
+from torch import Tensor
+
+from torchmetrics.utilities.checks import _check_same_shape
+
+
+def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_same_device: bool = False) -> Tensor:
+    r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
+    Note that input will be moved to `cpu` to perform the metric calculation.
+
+    Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due
+    to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
+    The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
+    alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
+    interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms,
+    on speech intelligibility. Description taken from `Cees Taal's website <http://www.ceestaal.nl/code/>`_.
+
+    .. note:: using this metric requires you to have ``pystoi`` installed. Either install as ``pip install
+        torchmetrics[audio]`` or ``pip install pystoi``
+
+    Args:
+        preds:
+            shape ``[...,time]``
+        target:
+            shape ``[...,time]``
+        fs:
+            sampling frequency (Hz)
+        extended:
+            whether to use the extended STOI described in [4]
+        keep_same_device:
+            whether to move the stoi value to the device of preds
+
+    Returns:
+        stoi value of shape [...]
+
+    Raises:
+        ValueError:
+            If ``pystoi`` package is not installed
+
+    Example:
+        >>> from torchmetrics.functional.audio import stoi
+        >>> import torch
+        >>> g = torch.manual_seed(1)
+        >>> preds = torch.randn(8000)
+        >>> target = torch.randn(8000)
+        >>> stoi(preds, target, 8000).float()
+        tensor(-0.0100)
+
+    References:
+        [1] https://github.com/mpariente/pystoi
+
+        [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time Objective Intelligibility Measure for
+        Time-Frequency Weighted Noisy Speech', ICASSP 2010, Texas, Dallas.
+
+        [3] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for Intelligibility Prediction of
+        Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+
+        [4] J. Jensen and C. H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated
+        Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016.
+
+    """
+    if not _PYSTOI_AVAILABLE:  # NOTE(review): the STOI class raises ModuleNotFoundError for the same condition
+        raise ValueError(
+            "STOI metric requires that pystoi is installed."
+            " Either install as `pip install torchmetrics[audio]` or `pip install pystoi`"
+        )
+    _check_same_shape(preds, target)  # reject mismatched shapes before any numpy conversion
+
+    if len(preds.shape) == 1:  # single signal: score it directly
+        stoi_val_np = stoi_backend(target.detach().cpu().numpy(), preds.detach().cpu().numpy(), fs, extended)
+        stoi_val = torch.tensor(stoi_val_np)
+    else:  # batched input: flatten leading dims into rows, score each row, then restore the leading shape
+        preds_np = preds.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
+        target_np = target.reshape(-1, preds.shape[-1]).detach().cpu().numpy()
+        stoi_val_np = np.empty(shape=(preds_np.shape[0]))
+        for b in range(preds_np.shape[0]):  # the backend is called with the target first, then preds
+            stoi_val_np[b] = stoi_backend(target_np[b, :], preds_np[b, :], fs, extended)
+        stoi_val = torch.from_numpy(stoi_val_np)
+        stoi_val = stoi_val.reshape(preds.shape[:-1])
+
+    if keep_same_device:  # otherwise the result stays on the CPU, where it was built from numpy
+        stoi_val = stoi_val.to(preds.device)
+
+    return stoi_val
diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py
index 909b11a4529..0f61d45860d 100644
--- a/torchmetrics/utilities/imports.py
+++ b/torchmetrics/utilities/imports.py
@@ -85,3 +85,4 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool]
 _PESQ_AVAILABLE: bool = _module_available("pesq")
 _SACREBLEU_AVAILABLE: bool = _module_available("sacrebleu")
 _REGEX_AVAILABLE: bool = _module_available("regex")
+_PYSTOI_AVAILABLE: bool = _module_available("pystoi")