diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index ab81cbac4b1..66abccbd121 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -37,9 +37,10 @@ jobs: working-directory: ./docs run: | # First run the same pipeline as Read-The-Docs - apt-get update && sudo apt-get install -y cmake + sudo apt-get update && sudo apt-get install -y cmake make doctest make coverage + shell: bash make-docs: runs-on: ubuntu-20.04 diff --git a/CHANGELOG.md b/CHANGELOG.md index 4eebc8c6cb5..597f1ee7e3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `MatthewsCorrcoef` -> `MatthewsCorrCoef` * `PearsonCorrcoef` -> `PearsonCorrCoef` * `SpearmanCorrcoef` -> `SpearmanCorrCoef` +- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753)) - Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751)) * `functional.audio.pesq` -> `functional.audio.perceptual_evaluation_speech_quality` * `audio.PESQ` -> `audio.PerceptualEvaluationSpeechQuality` diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 95bffbdb451..6011ceb4b93 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -111,10 +111,10 @@ SignalNoiseRatio .. autoclass:: torchmetrics.SignalNoiseRatio :noindex: -STOI -~~~~ +ShortTermObjectiveIntelligibility +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: torchmetrics.audio.stoi.STOI +.. autoclass:: torchmetrics.audio.stoi.ShortTermObjectiveIntelligibility :noindex: diff --git a/tests/audio/test_stoi.py b/tests/audio/test_stoi.py index cd4192e83d7..5d437de48fa 100644 --- a/tests/audio/test_stoi.py +++ b/tests/audio/test_stoi.py @@ -21,7 +21,7 @@ from tests.helpers import seed_all from tests.helpers.testers import MetricTester -from torchmetrics.audio.stoi import STOI +from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility from torchmetrics.functional.audio.stoi import stoi from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 @@ -82,7 +82,7 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st ddp, preds, target, - STOI, + ShortTermObjectiveIntelligibility, sk_metric=partial(average_metric, metric_func=sk_metric), dist_sync_on_step=dist_sync_on_step, metric_args=dict(fs=fs, extended=extended), @@ -101,7 +101,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended): self.run_differentiability_test( preds=preds, target=target, - metric_module=STOI, + metric_module=ShortTermObjectiveIntelligibility, metric_functional=stoi, metric_args=dict(fs=fs, extended=extended), ) @@ -117,13 +117,13 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended): self.run_precision_test_gpu( preds=preds, target=target, - metric_module=STOI, + metric_module=ShortTermObjectiveIntelligibility, metric_functional=partial(stoi, fs=fs, extended=extended), metric_args=dict(fs=fs, extended=extended), ) -def test_error_on_different_shape(metric_class=STOI): +def test_error_on_different_shape(metric_class=ShortTermObjectiveIntelligibility): metric = metric_class(16000) with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): metric(torch.randn(100), torch.randn(50)) diff --git a/torchmetrics/audio/__init__.py b/torchmetrics/audio/__init__.py index 22e5b17c93d..50f54d042b9 100644 --- a/torchmetrics/audio/__init__.py +++ b/torchmetrics/audio/__init__.py @@ -16,7 +16,10 @@ from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401 from torchmetrics.audio.si_snr import SI_SNR # noqa: F401 from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401 -from torchmetrics.utilities.imports import _PESQ_AVAILABLE +from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE if _PESQ_AVAILABLE: from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401 + +if _PYSTOI_AVAILABLE: + from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility # noqa: F401 diff --git a/torchmetrics/audio/stoi.py b/torchmetrics/audio/stoi.py index 67eb3d28e83..5ecffbcfbfa 100644 --- a/torchmetrics/audio/stoi.py +++ b/torchmetrics/audio/stoi.py @@ -13,14 +13,16 @@ # limitations under the License. from typing import Any, Callable, Optional +from deprecate import deprecated, void from torch import Tensor, tensor from torchmetrics.functional.audio.stoi import stoi from torchmetrics.metric import Metric +from torchmetrics.utilities import _future_warning from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE -class STOI(Metric): +class ShortTermObjectiveIntelligibility(Metric): r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. Note that input will be moved to `cpu` to perform the metric calculation. @@ -63,12 +65,12 @@ class STOI(Metric): If ``pystoi`` package is not installed Example: - >>> from torchmetrics.audio.stoi import STOI + >>> from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility >>> import torch >>> g = torch.manual_seed(1) >>> preds = torch.randn(8000) >>> target = torch.randn(8000) - >>> stoi = STOI(8000, False) + >>> stoi = ShortTermObjectiveIntelligibility(8000, False) >>> stoi(preds, target) tensor(-0.0100) @@ -131,3 +133,32 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: """Computes average STOI.""" return self.sum_stoi / self.total + + +class STOI(ShortTermObjectiveIntelligibility): + r"""STOI (Short Term Objective Intelligibility), a wrapper for the pystoi package. + + .. deprecated:: v0.7 + Use :class:`torchmetrics.audio.ShortTermObjectiveIntelligibility`. Will be removed in v0.8. + + Example: + >>> import torch + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> stoi = STOI(8000, False) + >>> stoi(preds, target) + tensor(-0.0100) + """ + + @deprecated(target=ShortTermObjectiveIntelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning) + def __init__( + self, + fs: int, + extended: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None, + ) -> None: + void(fs, extended, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index ff8acb12c96..17b98056468 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -82,7 +82,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa """ if not _PYSTOI_AVAILABLE: raise ModuleNotFoundError( - "STOI metric requires that `pystoi` is installed." + "ShortTermObjectiveIntelligibility metric requires that `pystoi` is installed." " Either install as `pip install torchmetrics[audio]` or `pip install pystoi`." ) _check_same_shape(preds, target)