Add speech-to-reverberation modulation energy ratio (SRMR) metric #1792

Merged Jun 29, 2023 (83 commits)
16f92f3
+SRMR
quancs May 16, 2023
08155e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2023
f5843ee
Merge branch 'master' into srmr
SkafteNicki May 22, 2023
8e1f6f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
8c370e4
update
quancs Jun 9, 2023
e8fe6ea
update
quancs Jun 9, 2023
8c6fa8a
Merge branch 'srmr' of https://github.com/quancs/torchmetrics into srmr
quancs Jun 9, 2023
698c6a4
update
quancs Jun 9, 2023
14b7ec1
update
quancs Jun 9, 2023
b7317be
Merge branch 'master' into srmr
quancs Jun 9, 2023
50f0511
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
ea46e37
fix
quancs Jun 9, 2023
2ee6b47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
cb7c5f7
fix
quancs Jun 9, 2023
f6fead7
fix
quancs Jun 9, 2023
f571d97
fix
quancs Jun 9, 2023
f2557af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
c34ad65
update
quancs Jun 9, 2023
9e33963
disbale differentiable test
quancs Jun 9, 2023
fc11680
Merge branch 'srmr' of https://github.com/quancs/torchmetrics into srmr
quancs Jun 9, 2023
ded7452
update
quancs Jun 9, 2023
e7d64a9
fix
quancs Jun 9, 2023
1f1968d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
7b23677
fix
quancs Jun 9, 2023
c878bae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
1b66bc7
change to lower letters
quancs Jun 9, 2023
c1b7b4c
merge
quancs Jun 9, 2023
3c14017
update
quancs Jun 9, 2023
7d58c56
Merge branch 'srmr' of https://github.com/quancs/torchmetrics into srmr
quancs Jun 9, 2023
c1655e6
fix
quancs Jun 9, 2023
df168db
fix
quancs Jun 9, 2023
9c4e64e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
3dfe378
remove assert
quancs Jun 9, 2023
a347417
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
d72d90b
shit pre-commit
quancs Jun 9, 2023
78dbe74
Merge branch 'srmr' of https://github.com/quancs/torchmetrics into srmr
quancs Jun 9, 2023
b244bf2
fix
quancs Jun 9, 2023
809f777
Merge branch 'srmr' of https://github.com/quancs/torchmetrics into srmr
quancs Jun 9, 2023
7761417
fix
quancs Jun 9, 2023
df5cbdc
fix
quancs Jun 9, 2023
0ed8dae
fix
quancs Jun 9, 2023
4f48e61
fix
quancs Jun 9, 2023
a6fa842
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2023
b976ab0
fix
quancs Jun 12, 2023
c4489c9
srmrpy
quancs Jun 12, 2023
345a818
update gammatone
quancs Jun 12, 2023
3645aaf
Merge branch 'master' into srmr
SkafteNicki Jun 13, 2023
892ad2e
Merge branch 'master' into srmr
SkafteNicki Jun 15, 2023
042da5d
Update src/torchmetrics/audio/srmr.py
quancs Jun 15, 2023
84992d4
Update src/torchmetrics/audio/srmr.py
quancs Jun 15, 2023
39abb66
Update src/torchmetrics/audio/srmr.py
quancs Jun 15, 2023
4752264
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2023
6133b2f
Update src/torchmetrics/functional/audio/srmr.py
quancs Jun 15, 2023
1ede251
Update tests/unittests/audio/test_srmr.py
quancs Jun 15, 2023
d6734ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2023
3546ef4
add _srmr_arg_validate
quancs Jun 15, 2023
278efda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2023
9d1c8b9
Merge branch 'master' into srmr
SkafteNicki Jun 16, 2023
42404ff
fix ruff issues
SkafteNicki Jun 16, 2023
cc56270
add plot testing
SkafteNicki Jun 16, 2023
bd57e42
remove gammatone in requirements
quancs Jun 17, 2023
18b4698
fix doc
quancs Jun 17, 2023
9263be7
fix pi
quancs Jun 17, 2023
ea9c811
fix
quancs Jun 17, 2023
48e0fc2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2023
e066f71
add docs for lfilter
quancs Jun 17, 2023
84737a7
update
quancs Jun 17, 2023
cc289dc
Merge branch 'master' into srmr
quancs Jun 17, 2023
5598f87
add gammatone
quancs Jun 17, 2023
06a5826
torchaudio>=0.10.0
quancs Jun 17, 2023
221dd66
skip testing on missinig install
SkafteNicki Jun 19, 2023
fae0bc8
fix imports during testing
SkafteNicki Jun 19, 2023
70b628d
fix plot import
SkafteNicki Jun 19, 2023
8962749
skip conditional plot testing
SkafteNicki Jun 19, 2023
8242f5c
srmr
SkafteNicki Jun 19, 2023
968b391
skipping
SkafteNicki Jun 19, 2023
03c1f6b
Merge branch 'master' into srmr
SkafteNicki Jun 20, 2023
6521d4b
Merge branch 'master' into srmr
SkafteNicki Jun 28, 2023
bce9b9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2023
0b370a1
fix formatting
SkafteNicki Jun 28, 2023
6ce353c
Merge branch 'master' into srmr
Borda Jun 29, 2023
88e99f9
Merge branch 'master' into srmr
Borda Jun 29, 2023
1e8145a
Update src/torchmetrics/audio/srmr.py
SkafteNicki Jun 29, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792))


- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))


@@ -0,0 +1,23 @@
.. customcarditem::
   :header: Speech-to-Reverberation Modulation Energy Ratio (SRMR)
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
   :tags: Audio

.. include:: ../links.rst

######################################################
Speech-to-Reverberation Modulation Energy Ratio (SRMR)
######################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.srmr.SpeechReverberationModulationEnergyRatio
    :noindex:
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.srmr.speech_reverberation_modulation_energy_ratio
    :noindex:
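
For orientation, the ratio behind the metric's name can be sketched as below. This is a minimal illustration, not the PR's implementation: it assumes the time-averaged modulation energies per (acoustic band, modulation band) are already available, and the names `srmr_ratio`, `avg_energy`, and the fixed `k_star` default are hypothetical. SRMRpy picks the upper modulation band adaptively from the signal bandwidth; here it is simply a parameter.

```python
from typing import Sequence


def srmr_ratio(avg_energy: Sequence[Sequence[float]], k_star: int = 8) -> float:
    """Ratio of low-frequency to high-frequency modulation energy.

    avg_energy[j][k] is assumed to hold the time-averaged energy of
    acoustic band j in modulation band k (hypothetical layout).
    k_star is the last modulation band counted in the denominator.
    """
    n_mod_bands = len(avg_energy[0])
    if not 4 < k_star <= n_mod_bands:
        raise ValueError("k_star must lie in (4, number of modulation bands]")
    # Energy in the first four modulation bands (mostly speech).
    low = sum(sum(row[:4]) for row in avg_energy)
    # Energy in modulation bands 5..k_star (where reverberation tails show up).
    high = sum(sum(row[4:k_star]) for row in avg_energy)
    return low / high


# Two acoustic bands, eight modulation bands (made-up numbers).
energy = [
    [4.0, 3.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5],
    [2.0, 2.0, 1.0, 1.0, 0.25, 0.25, 0.25, 0.25],
]
score = srmr_ratio(energy)
```

A larger value means more of the modulation energy sits in the low bands, which is consistent with `higher_is_better = True` on the module class.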
3 changes: 3 additions & 0 deletions docs/source/links.rst
@@ -118,6 +118,9 @@
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Complex scale-invariant signal-to-noise ratio: https://arxiv.org/abs/2011.09162
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Speech-to-Reverberation Modulation Energy Ratio: https://ieeexplore.ieee.org/document/5547575
.. _SRMRToolbox: https://github.com/MuSAELab/SRMRToolbox
.. _SRMRpy: https://github.com/jfsantos/SRMRpy
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
.. _ranking ref1: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Spectral Distortion Index: https://www.ingentaconnect.com/content/asprs/pers/2008/00000074/00000002/art00003;jsessionid=nzjnb3v9xxr1.x-ic-live-03
2 changes: 2 additions & 0 deletions requirements/audio.txt
@@ -4,3 +4,5 @@
# this need to be the same as used inside speechmetrics
pesq @ git+https://github.com/ludlows/python-pesq
pystoi <=0.3.3
torchaudio >=0.10.0
gammatone @ https://github.com/detly/gammatone/archive/master.zip#egg=Gammatone
1 change: 1 addition & 0 deletions requirements/audio_test.txt
@@ -6,3 +6,4 @@ mir-eval >=0.6, <=0.7
speechmetrics @ git+https://github.com/aliutkus/speechmetrics
fast-bss-eval >=0.1.0, <0.1.5
torch_complex <=0.4.3 # needed for fast-bss-eval
srmrpy @ git+https://github.com/jfsantos/SRMRpy
Member review comment: note that this won't be installed with PyPI.

Member review comment: it turned out that this repo is not tested, so we should not rely on parity with its numbers.
ref: jfsantos/SRMRpy#15

13 changes: 12 additions & 1 deletion src/torchmetrics/audio/__init__.py
@@ -18,7 +18,13 @@
    ScaleInvariantSignalNoiseRatio,
    SignalNoiseRatio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE
from torchmetrics.utilities.imports import (
    _GAMMATONE_AVAILABEL,
    _PESQ_AVAILABLE,
    _PYSTOI_AVAILABLE,
    _TORCHAUDIO_AVAILABEL,
    _TORCHAUDIO_GREATER_EQUAL_0_10,
)

__all__ = [
    "PermutationInvariantTraining",
@@ -38,3 +44,8 @@
    from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

    __all__.append("ShortTimeObjectiveIntelligibility")

if _GAMMATONE_AVAILABEL and _TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10:
    from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio

    __all__.append("SpeechReverberationModulationEnergyRatio")
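
The conditional export above follows a common optional-dependency pattern. A minimal sketch of it, using only the standard library, is shown below. The helper `package_available` and the list `exported` are hypothetical names for illustration; the flags the PR imports (`_GAMMATONE_AVAILABEL`, `_TORCHAUDIO_AVAILABEL`) come from `torchmetrics.utilities.imports` and may be computed differently. The sketch also leaves out the version check behind `_TORCHAUDIO_GREATER_EQUAL_0_10`.

```python
from importlib.util import find_spec


def package_available(name: str) -> bool:
    """Return True if a top-level package can be located without importing it."""
    try:
        return find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for broken or partially initialised modules.
        return False


# Names that are always exported (stand-in for the real __all__).
exported = ["SignalNoiseRatio"]

# Only advertise the SRMR metric when both optional packages are present.
if package_available("gammatone") and package_available("torchaudio"):
    exported.append("SpeechReverberationModulationEnergyRatio")
```

Checking availability before the import keeps `import torchmetrics.audio` working on installs that lack the optional audio extras.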
179 changes: 179 additions & 0 deletions src/torchmetrics/audio/srmr.py
@@ -0,0 +1,179 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.srmr import (
    _srmr_arg_validate,
    speech_reverberation_modulation_energy_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
    _GAMMATONE_AVAILABEL,
    _MATPLOTLIB_AVAILABLE,
    _TORCHAUDIO_AVAILABEL,
    _TORCHAUDIO_GREATER_EQUAL_0_10,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not all([_GAMMATONE_AVAILABEL, _TORCHAUDIO_AVAILABEL, _TORCHAUDIO_GREATER_EQUAL_0_10]):
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]


class SpeechReverberationModulationEnergyRatio(Metric):
    """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).

    SRMR is a non-intrusive metric for speech quality and intelligibility based on
    a modulation spectral representation of the speech signal.
    This code is translated from `SRMRToolbox`_ and `SRMRpy`_.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``srmr`` (:class:`~torch.Tensor`): float scalar tensor

    .. note:: using this metric requires you to have ``gammatone`` and ``torchaudio`` installed.
        Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
        and ``pip install git+https://github.com/detly/gammatone``.

    Args:
        fs: the sampling rate
        n_cochlear_filters: Number of filters in the acoustic filterbank
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
        min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 30 Hz will be used for `norm==False`, otherwise 128 Hz will be used.
        norm: Use modulation spectrum energy normalization
        fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. Because the translated code is based on PyTorch,
            setting `fast=True` may slow down the calculation of this metric on GPU.

    Raises:
        ModuleNotFoundError:
            If ``gammatone`` or ``torchaudio`` package is not installed

    Example:
        >>> import torch
        >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> srmr = SpeechReverberationModulationEnergyRatio(8000)
        >>> srmr(preds)
        tensor(0.3354)
    """

    msum: Tensor
    total: Tensor
    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = True
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def __init__(
        self,
        fs: int,
        n_cochlear_filters: int = 23,
        low_freq: float = 125,
        min_cf: float = 4,
        max_cf: Optional[float] = 128,
        norm: bool = False,
        fast: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not _TORCHAUDIO_AVAILABEL or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABEL:
            raise ModuleNotFoundError(
                "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
                " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
                "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
            )
        _srmr_arg_validate(
            fs=fs,
            n_cochlear_filters=n_cochlear_filters,
            low_freq=low_freq,
            min_cf=min_cf,
            max_cf=max_cf,
            norm=norm,
            fast=fast,
        )

        self.fs = fs
        self.n_cochlear_filters = n_cochlear_filters
        self.low_freq = low_freq
        self.min_cf = min_cf
        self.max_cf = max_cf
        self.norm = norm
        self.fast = fast

        self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Update state with predictions."""
        metric_val_batch = speech_reverberation_modulation_energy_ratio(
            preds, self.fs, self.n_cochlear_filters, self.low_freq, self.min_cf, self.max_cf, self.norm, self.fast
        ).to(self.msum.device)

        self.msum += metric_val_batch.sum()
        self.total += metric_val_batch.numel()

    def compute(self) -> Tensor:
        """Compute metric."""
        return self.msum / self.total

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, the plot will be added to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> metric.update(torch.rand(8000))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(8000)))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)
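
The `update`/`compute` pair above keeps a running sum and a running count, so the final value is a mean over all signals seen rather than a mean of per-batch means. A dependency-free sketch of that bookkeeping follows; the class name `RunningMean` and the plain Python floats are assumptions for illustration, whereas the real metric stores tensors via `add_state` so that they can be summed across processes.

```python
from typing import Iterable


class RunningMean:
    """Accumulate per-signal scores and report their overall mean."""

    def __init__(self) -> None:
        self.msum = 0.0  # sum of all scores seen so far
        self.total = 0   # number of scores seen so far

    def update(self, batch_scores: Iterable[float]) -> None:
        scores = list(batch_scores)
        self.msum += sum(scores)
        self.total += len(scores)

    def compute(self) -> float:
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.msum / self.total


metric = RunningMean()
metric.update([1.0, 2.0, 3.0])  # first batch: three signals
metric.update([5.0])            # second batch: one signal
result = metric.compute()
```

With these inputs the sample-weighted mean is 11 / 4 = 2.75, whereas averaging the two batch means (2.0 and 5.0) would give 3.5. Keeping `msum` and `total` as separate summed states is what makes the result independent of how the data is split into batches.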
13 changes: 12 additions & 1 deletion src/torchmetrics/functional/audio/__init__.py
@@ -18,7 +18,13 @@
    scale_invariant_signal_noise_ratio,
    signal_noise_ratio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE
from torchmetrics.utilities.imports import (
    _GAMMATONE_AVAILABEL,
    _PESQ_AVAILABLE,
    _PYSTOI_AVAILABLE,
    _TORCHAUDIO_AVAILABEL,
    _TORCHAUDIO_GREATER_EQUAL_0_10,
)

__all__ = [
    "permutation_invariant_training",
@@ -39,3 +45,8 @@
    from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility

    __all__.append("short_time_objective_intelligibility")

if _GAMMATONE_AVAILABEL and _TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10:
    from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio

    __all__.append("speech_reverberation_modulation_energy_ratio")