Skip to content

Commit

Permalink
rename STOI (#753)
Browse files Browse the repository at this point in the history
* ShortTermObjectiveIntelligibility
* docs
  • Loading branch information
Borda authored Jan 14, 2022
1 parent 06d820b commit d044d33
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ jobs:
working-directory: ./docs
run: |
# First run the same pipeline as Read-The-Docs
apt-get update && sudo apt-get install -y cmake
sudo apt-get update && sudo apt-get install -y cmake
make doctest
make coverage
shell: bash

make-docs:
runs-on: ubuntu-20.04
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `MatthewsCorrcoef` -> `MatthewsCorrCoef`
* `PearsonCorrcoef` -> `PearsonCorrCoef`
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`
- Renamed audio STOI metric `audio.STOI` to `audio.ShortTermObjectiveIntelligibility` ([#753](https://github.com/PyTorchLightning/metrics/pull/753))
- Renamed audio PESQ metrics: ([#751](https://github.com/PyTorchLightning/metrics/pull/751))
* `functional.audio.pesq` -> `functional.audio.perceptual_evaluation_speech_quality`
* `audio.PESQ` -> `audio.PerceptualEvaluationSpeechQuality`
Expand Down
6 changes: 3 additions & 3 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ SignalNoiseRatio
.. autoclass:: torchmetrics.SignalNoiseRatio
:noindex:

STOI
~~~~
ShortTermObjectiveIntelligibility
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.audio.stoi.STOI
.. autoclass:: torchmetrics.audio.stoi.ShortTermObjectiveIntelligibility
:noindex:


Expand Down
10 changes: 5 additions & 5 deletions tests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio.stoi import STOI
from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_st
ddp,
preds,
target,
STOI,
ShortTermObjectiveIntelligibility,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, extended=extended),
Expand All @@ -101,7 +101,7 @@ def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=STOI,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=stoi,
metric_args=dict(fs=fs, extended=extended),
)
Expand All @@ -117,13 +117,13 @@ def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=STOI,
metric_module=ShortTermObjectiveIntelligibility,
metric_functional=partial(stoi, fs=fs, extended=extended),
metric_args=dict(fs=fs, extended=extended),
)


def test_error_on_different_shape(metric_class=STOI):
def test_error_on_different_shape(metric_class=ShortTermObjectiveIntelligibility):
metric = metric_class(16000)
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Expand Down
5 changes: 4 additions & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE

if _PESQ_AVAILABLE:
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401

if _PYSTOI_AVAILABLE:
from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility # noqa: F401
37 changes: 34 additions & 3 deletions torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.
from typing import Any, Callable, Optional

from deprecate import deprecated, void
from torch import Tensor, tensor

from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE


class STOI(Metric):
class ShortTermObjectiveIntelligibility(Metric):
r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to `cpu` to perform the metric calculation.
Expand Down Expand Up @@ -63,12 +65,12 @@ class STOI(Metric):
If ``pystoi`` package is not installed
Example:
>>> from torchmetrics.audio.stoi import STOI
>>> from torchmetrics.audio.stoi import ShortTermObjectiveIntelligibility
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi = ShortTermObjectiveIntelligibility(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
Expand Down Expand Up @@ -131,3 +133,32 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
def compute(self) -> Tensor:
"""Computes average STOI."""
return self.sum_stoi / self.total


class STOI(ShortTermObjectiveIntelligibility):
r"""STOI (Short Term Objective Intelligibility), a wrapper for the pystoi package.
.. deprecated:: v0.7
Use :class:`torchmetrics.audio.ShortTermObjectiveIntelligibility`. Will be removed in v0.8.
Example:
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
"""

@deprecated(target=ShortTermObjectiveIntelligibility, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
void(fs, extended, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
2 changes: 1 addition & 1 deletion torchmetrics/functional/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def stoi(preds: Tensor, target: Tensor, fs: int, extended: bool = False, keep_sa
"""
if not _PYSTOI_AVAILABLE:
raise ModuleNotFoundError(
"STOI metric requires that `pystoi` is installed."
"ShortTermObjectiveIntelligibility metric requires that `pystoi` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pystoi`."
)
_check_same_shape(preds, target)
Expand Down

0 comments on commit d044d33

Please sign in to comment.