Skip to content

Commit

Permalink
Add speech-to-reverberation modulation energy ratio (SRMR) metric (#1792)
Browse files Browse the repository at this point in the history

* +SRMR

* update

* update

* update

* update

* fix

* fix

* fix

* fix

* update

* disable differentiable test

* update

* fix

* fix

* change to lower letters

* update

* fix

* fix

* remove assert

* shit pre-commit

* fix

* fix

* fix

* fix

* fix

* fix

* srmrpy

* update gammatone

* Update src/torchmetrics/audio/srmr.py

* Update src/torchmetrics/audio/srmr.py

* Update src/torchmetrics/audio/srmr.py

* Update src/torchmetrics/functional/audio/srmr.py

* Update tests/unittests/audio/test_srmr.py

* add _srmr_arg_validate

* fix ruff issues

* add plot testing

* remove gammatone in requirements

* fix doc

* fix pi

* fix

* add docs for lfilter

* update

* add gammatone

* torchaudio>=0.10.0

* skip testing on missing install

* fix imports during testing

* fix plot import

* skip conditional plot testing

* srmr

* skipping

* fix formatting

* Update src/torchmetrics/audio/srmr.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
4 people authored Jun 29, 2023
1 parent cc9f4de commit 372f22e
Show file tree
Hide file tree
Showing 12 changed files with 728 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792))


- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))


Expand Down
23 changes: 23 additions & 0 deletions docs/source/audio/speech_reverberation_modulation_energy_ratio.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
   :header: Speech-to-Reverberation Modulation Energy Ratio (SRMR)
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
   :tags: Audio

.. include:: ../links.rst

######################################################
Speech-to-Reverberation Modulation Energy Ratio (SRMR)
######################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.srmr.SpeechReverberationModulationEnergyRatio
    :noindex:
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.srmr.speech_reverberation_modulation_energy_ratio
    :noindex:
3 changes: 3 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Complex scale-invariant signal-to-noise ratio: https://arxiv.org/abs/2011.09162
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Speech-to-Reverberation Modulation Energy Ratio: https://ieeexplore.ieee.org/document/5547575
.. _SRMRToolbox: https://github.com/MuSAELab/SRMRToolbox
.. _SRMRpy: https://github.com/jfsantos/SRMRpy
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
.. _ranking ref1: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Spectral Distortion Index: https://www.ingentaconnect.com/content/asprs/pers/2008/00000074/00000002/art00003;jsessionid=nzjnb3v9xxr1.x-ic-live-03
Expand Down
2 changes: 2 additions & 0 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
# this need to be the same as used inside speechmetrics
pesq @ git+https://github.com/ludlows/python-pesq
pystoi <=0.3.3
torchaudio >=0.10.0
gammatone @ https://github.com/detly/gammatone/archive/master.zip#egg=Gammatone
1 change: 1 addition & 0 deletions requirements/audio_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ mir-eval >=0.6, <=0.7
speechmetrics @ git+https://github.com/aliutkus/speechmetrics
fast-bss-eval >=0.1.0, <0.1.5
torch_complex <=0.4.3 # needed for fast-bss-eval
srmrpy @ git+https://github.com/jfsantos/SRMRpy
13 changes: 12 additions & 1 deletion src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
ScaleInvariantSignalNoiseRatio,
SignalNoiseRatio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE
from torchmetrics.utilities.imports import (
_GAMMATONE_AVAILABEL,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_TORCHAUDIO_AVAILABEL,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

__all__ = [
"PermutationInvariantTraining",
Expand All @@ -38,3 +44,8 @@
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

__all__.append("ShortTimeObjectiveIntelligibility")

# Only import and export the SRMR metric when gammatone is available and torchaudio is
# available at version >= 0.10, as indicated by the flags imported above.
if _GAMMATONE_AVAILABEL and _TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10:
    from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio

    __all__.append("SpeechReverberationModulationEnergyRatio")
179 changes: 179 additions & 0 deletions src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.srmr import (
_srmr_arg_validate,
speech_reverberation_modulation_energy_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
_GAMMATONE_AVAILABEL,
_MATPLOTLIB_AVAILABLE,
_TORCHAUDIO_AVAILABEL,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

# Doctest selection: without gammatone / torchaudio>=0.10 every doctest of the class is skipped;
# with those present but matplotlib missing, only the ``plot`` doctest is skipped.
if not (_GAMMATONE_AVAILABEL and _TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10):
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]


class SpeechReverberationModulationEnergyRatio(Metric):
    """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).

    SRMR is a non-intrusive metric for speech quality and intelligibility based on
    a modulation spectral representation of the speech signal.
    This code is translated from `SRMRToolbox`_ and `SRMRpy`_.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``srmr`` (:class:`~torch.Tensor`): float scalar tensor

    .. note:: using this metric requires you to have ``gammatone`` and ``torchaudio`` installed.
        Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
        and ``pip install git+https://github.com/detly/gammatone``.

    Args:
        fs: the sampling rate
        n_cochlear_filters: Number of filters in the acoustic filterbank
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
        min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. Defaults to 128 Hz here. If ``None`` is
            given, the value is chosen by the functional implementation depending on ``norm``
            (either 30 Hz or 128 Hz).
        norm: Use modulation spectrum energy normalization
        fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch,
            setting `fast=True` may slow down the speed for calculating this metric on GPU.
        kwargs: Additional keyword arguments forwarded to the :class:`~torchmetrics.metric.Metric` base class.

    Raises:
        ModuleNotFoundError:
            If ``gammatone`` or ``torchaudio`` package is not installed

    Example:
        >>> import torch
        >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> srmr = SpeechReverberationModulationEnergyRatio(8000)
        >>> srmr(preds)
        tensor(0.3354)
    """

    # Accumulated states: running sum of SRMR values and the number of values summed.
    msum: Tensor
    total: Tensor
    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = True
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def __init__(
        self,
        fs: int,
        n_cochlear_filters: int = 23,
        low_freq: float = 125,
        min_cf: float = 4,
        max_cf: Optional[float] = 128,
        norm: bool = False,
        fast: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        # Fail early with an actionable message when an optional dependency is missing.
        if not (_TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10 and _GAMMATONE_AVAILABEL):
            raise ModuleNotFoundError(
                "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
                " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
                "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
            )

        # Validate the configuration once at construction time, before anything is stored.
        srmr_config = {
            "fs": fs,
            "n_cochlear_filters": n_cochlear_filters,
            "low_freq": low_freq,
            "min_cf": min_cf,
            "max_cf": max_cf,
            "norm": norm,
            "fast": fast,
        }
        _srmr_arg_validate(**srmr_config)

        self.fs = fs
        self.n_cochlear_filters = n_cochlear_filters
        self.low_freq = low_freq
        self.min_cf = min_cf
        self.max_cf = max_cf
        self.norm = norm
        self.fast = fast

        # Both states are summed when reduced across processes.
        self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Update state with predictions."""
        batch_scores = speech_reverberation_modulation_energy_ratio(
            preds,
            self.fs,
            self.n_cochlear_filters,
            self.low_freq,
            self.min_cf,
            self.max_cf,
            self.norm,
            self.fast,
        )
        # Move the scores to the device holding the accumulator before adding them up.
        batch_scores = batch_scores.to(self.msum.device)

        self.msum += batch_scores.sum()
        self.total += batch_scores.numel()

    def compute(self) -> Tensor:
        """Compute metric."""
        return self.msum / self.total

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> metric.update(torch.rand(8000))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(8000)))
            >>> fig_, ax_ = metric.plot(values)
        """
        return self._plot(val, ax)
13 changes: 12 additions & 1 deletion src/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
scale_invariant_signal_noise_ratio,
signal_noise_ratio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE
from torchmetrics.utilities.imports import (
_GAMMATONE_AVAILABEL,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_TORCHAUDIO_AVAILABEL,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

__all__ = [
"permutation_invariant_training",
Expand All @@ -39,3 +45,8 @@
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility

__all__.append("short_time_objective_intelligibility")

# Only import and export the functional SRMR metric when gammatone is available and
# torchaudio is available at version >= 0.10, as indicated by the flags imported above.
if _GAMMATONE_AVAILABEL and _TORCHAUDIO_AVAILABEL and _TORCHAUDIO_GREATER_EQUAL_0_10:
    from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio

    __all__.append("speech_reverberation_modulation_energy_ratio")
Loading

0 comments on commit 372f22e

Please sign in to comment.