Skip to content

Commit

Permalink
Merge branch 'master' into pairwise
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 14, 2021
2 parents 9bc02bb + 8bbc750 commit 51b81b4
Show file tree
Hide file tree
Showing 12 changed files with 408 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `pairwise_manhatten_distance`


- Added Short Term Objective Intelligibility (`STOI`) ([#353](https://github.com/PyTorchLightning/metrics/issues/353))


### Changed

- `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
Expand Down
7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ snr [func]
:noindex:


stoi [func]
~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.stoi
:noindex:


**********************
Classification Metrics
**********************
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ SNR
.. autoclass:: torchmetrics.SNR
:noindex:

STOI
~~~~

.. autoclass:: torchmetrics.STOI
:noindex:


**********************
Classification Metrics
Expand Down
1 change: 1 addition & 0 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pesq>=0.0.3
pystoi
146 changes: 146 additions & 0 deletions tests/audio/test_stoi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial

import pytest
import torch
from pystoi import stoi as stoi_backend
from torch import Tensor

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from torchmetrics.audio import STOI
from torchmetrics.functional import stoi
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs_8k = Input(
preds=torch.rand(2, 3, 8000),
target=torch.rand(2, 3, 8000),
)
inputs_16k = Input(
preds=torch.rand(2, 3, 16000),
target=torch.rand(2, 3, 16000),
)


def stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool):
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
target = target.detach().cpu().numpy()
preds = preds.detach().cpu().numpy()
mss = []
for b in range(preds.shape[0]):
pesq_val = stoi_backend(target[b, ...], preds[b, ...], fs, extended)
mss.append(pesq_val)
return torch.tensor(mss)


def average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


stoi_original_batch_8k_ext = partial(stoi_original_batch, fs=8000, extended=True)
stoi_original_batch_16k_ext = partial(stoi_original_batch, fs=16000, extended=True)
stoi_original_batch_8k_noext = partial(stoi_original_batch, fs=8000, extended=False)
stoi_original_batch_16k_noext = partial(stoi_original_batch, fs=16000, extended=False)


@pytest.mark.parametrize(
"preds, target, sk_metric, fs, extended",
[
(inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_ext, 8000, True),
(inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_ext, 16000, True),
(inputs_8k.preds, inputs_8k.target, stoi_original_batch_8k_noext, 8000, False),
(inputs_16k.preds, inputs_16k.target, stoi_original_batch_16k_noext, 16000, False),
],
)
class TestSTOI(MetricTester):
atol = 1e-2

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_stoi(self, preds, target, sk_metric, fs, extended, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
STOI,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(fs=fs, extended=extended),
)

def test_stoi_functional(self, preds, target, sk_metric, fs, extended):
self.run_functional_metric_test(
preds,
target,
stoi,
sk_metric,
metric_args=dict(fs=fs, extended=extended),
)

def test_stoi_differentiability(self, preds, target, sk_metric, fs, extended):
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=STOI,
metric_functional=stoi,
metric_args=dict(fs=fs, extended=extended),
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
)
def test_stoi_half_cpu(self, preds, target, sk_metric, fs, extended):
pytest.xfail("STOI metric does not support cpu + half precision")

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_stoi_half_gpu(self, preds, target, sk_metric, fs, extended):
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=STOI,
metric_functional=partial(stoi, fs=fs, extended=extended),
metric_args=dict(fs=fs, extended=extended),
)


def test_error_on_different_shape(metric_class=STOI):
metric = metric_class(16000)
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_on_real_audio():
import os

from scipy.io import wavfile

current_file_dir = os.path.dirname(__file__)

rate, ref = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech.wav"))
rate, deg = wavfile.read(os.path.join(current_file_dir, "examples/audio_speech_bab_0dB.wav"))
assert torch.allclose(
stoi(torch.from_numpy(deg), torch.from_numpy(ref), rate).float(),
torch.tensor(0.6739177),
rtol=0.0001,
atol=1e-4,
)
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR # noqa: E402
from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR, STOI # noqa: E402
from torchmetrics.classification import ( # noqa: E402
AUC,
AUROC,
Expand Down Expand Up @@ -131,6 +131,7 @@
"Specificity",
"SSIM",
"StatScores",
"STOI",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"WER",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
from torchmetrics.audio.stoi import STOI # noqa: F401
133 changes: 133 additions & 0 deletions torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

from torch import Tensor, tensor

from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE


class STOI(Metric):
r"""STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1].
Note that input will be moved to `cpu` to perform the metric calculation.
Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due
to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations.
The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms,
on speech intelligibility. Description taken from [Cees Taal's website](http://www.ceestaal.nl/code/).
.. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install
torchmetrics[audio]`` or ``pip install pystoi``
Forward accepts
- ``preds``: ``shape [...,time]``
- ``target``: ``shape [...,time]``
Args:
fs:
sampling frequency (Hz)
extended:
whether to use the extended STOI described in [4]
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
Returns:
average STOI value
Raises:
ModuleNotFoundError:
If ``pystoi`` package is not installed
Example:
>>> from torchmetrics.audio import STOI
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
References:
[1] https://github.com/mpariente/pystoi
[2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time Objective Intelligibility Measure for
Time-Frequency Weighted Noisy Speech', ICASSP 2010, Texas, Dallas.
[3] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for Intelligibility Prediction of
Time-Frequency Weighted Noisy Speech', IEEE Transactions on Audio, Speech, and Language Processing, 2011.
[4] J. Jensen and C. H. Taal, 'An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated
Noise Maskers', IEEE Transactions on Audio, Speech and Language Processing, 2016.
"""
sum_stoi: Tensor
total: Tensor
is_differentiable = False
higher_is_better = True

def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
if not _PYSTOI_AVAILABLE:
raise ModuleNotFoundError(
"STOI metric requires that pystoi is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install pystoi`"
)
self.fs = fs
self.extended = extended

self.add_state("sum_stoi", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
stoi_batch = stoi(preds, target, self.fs, self.extended, False).to(self.sum_stoi.device)

self.sum_stoi += stoi_batch.sum()
self.total += stoi_batch.numel()

def compute(self) -> Tensor:
"""Computes average STOI."""
return self.sum_stoi / self.total
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchmetrics.functional.audio.si_sdr import si_sdr
from torchmetrics.functional.audio.si_snr import si_snr
from torchmetrics.functional.audio.snr import snr
from torchmetrics.functional.audio.stoi import stoi
from torchmetrics.functional.classification.accuracy import accuracy
from torchmetrics.functional.classification.auc import auc
from torchmetrics.functional.classification.auroc import auroc
Expand Down Expand Up @@ -127,6 +128,7 @@
"specificity",
"ssim",
"stat_scores",
"stoi",
"symmetric_mean_absolute_percentage_error",
"wer",
]
1 change: 1 addition & 0 deletions torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from torchmetrics.functional.audio.si_sdr import si_sdr # noqa: F401
from torchmetrics.functional.audio.si_snr import si_snr # noqa: F401
from torchmetrics.functional.audio.snr import snr # noqa: F401
from torchmetrics.functional.audio.stoi import stoi # noqa: F401
Loading

0 comments on commit 51b81b4

Please sign in to comment.