Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-implement the signal_distortion_ratio metric #964

Merged
merged 40 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
eedeecc
fix sdr in 1.12
quancs Mar 23, 2022
f40f332
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
5a4273f
Merge branch 'master' into master
mergify[bot] Mar 23, 2022
2f5b59e
Update tests/audio/test_sdr.py
quancs Mar 23, 2022
19a7006
Merge branch 'master' into master
mergify[bot] Mar 23, 2022
a7a0d0b
update
quancs Mar 23, 2022
f1d7541
Merge branch 'PyTorchLightning:master' into master
quancs Apr 17, 2022
2899ddc
reimplement signal_distortion_ratio
quancs Apr 17, 2022
39b37a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2022
1121751
sdr is differentiable for all supported pytorch version now
quancs Apr 17, 2022
e3b1093
Merge branch 'master' of https://github.com/quancs/metrics
quancs Apr 17, 2022
95dfb8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2022
027443f
update
quancs Apr 17, 2022
1fa2656
Merge branch 'master' of https://github.com/quancs/metrics
quancs Apr 17, 2022
93ced9f
Merge branch 'master' into master
Borda Apr 20, 2022
f6fbceb
chlog
Borda Apr 20, 2022
833901e
Apply suggestions from code review
quancs Apr 20, 2022
40a846e
Apply suggestions from code review
quancs Apr 20, 2022
21ae77a
rename
quancs Apr 20, 2022
1114672
remove _FAST_BSS_EVAL_AVAILABLE
quancs Apr 20, 2022
7e4a595
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 20, 2022
94ad23b
fix too long line
SkafteNicki Apr 20, 2022
8616cb0
Merge branch 'master' into master
mergify[bot] Apr 20, 2022
54418c1
rename
quancs Apr 20, 2022
af9ccac
fix
quancs Apr 20, 2022
3dfaaba
Merge branches 'master' and 'master' of https://github.com/quancs/met…
quancs Apr 20, 2022
436cd64
fft
quancs Apr 21, 2022
18654bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
38ec931
fix keepdim
quancs Apr 21, 2022
3a19194
fix solve
quancs Apr 21, 2022
b15743b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
a213c19
format
quancs Apr 21, 2022
40798dc
Merge branch 'master' of https://github.com/quancs/metrics
quancs Apr 21, 2022
c433a60
use double precision
quancs Apr 21, 2022
e847d16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
8e7cf5f
remove _TORCH_GREATER_EQUAL_1_7
quancs Apr 21, 2022
79cd144
Merge branch 'master' of https://github.com/quancs/metrics
quancs Apr 21, 2022
7732764
fix rfft/irfft
quancs Apr 21, 2022
9d91d3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2022
9343d5d
remove _TORCH_LOWER_1_12_DEV
quancs Apr 21, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Reimplemented the `signal_distortion_ratio` metric, removing the hard requirement on `fast-bss-eval` ([#964](https://github.com/PyTorchLightning/metrics/pull/964))


-
Expand Down
2 changes: 0 additions & 2 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
pesq>=0.0.3
pystoi
fast-bss-eval>=0.1.0
torch_complex # needed for fast-bss-eval torch<=1.7
2 changes: 2 additions & 0 deletions requirements/audio_test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pypesq
mir_eval>=0.6
speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip
fast-bss-eval>=0.1.0
torch_complex # needed for fast-bss-eval torch<=1.7
7 changes: 3 additions & 4 deletions tests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tests.helpers.testers import MetricTester
from torchmetrics.audio import SignalDistortionRatio
from torchmetrics.functional import signal_distortion_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12_DEV
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_LOWER_1_12_DEV

seed_all(42)

Expand Down Expand Up @@ -99,7 +99,6 @@ def test_sdr_functional(self, preds, target, sk_metric):
metric_args=dict(),
)

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="sdr is not differentiable for pytorch < 1.8")
def test_sdr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(
preds=preds,
Expand Down Expand Up @@ -155,13 +154,13 @@ def test_too_low_precision():
preds = torch.tensor(data["preds"])
target = torch.tensor(data["target"])

if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12_DEV:
if _TORCH_LOWER_1_12_DEV:
with pytest.warns(
UserWarning,
match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision",
):
sdr_tm = signal_distortion_ratio(preds, target)
else: # when pytorch < 1.8 or pytorch >= 1.12, sdr doesn't have this problem
else: # when pytorch >= 1.12, sdr doesn't have this problem
sdr_tm = signal_distortion_ratio(preds, target).double()

# check equality with bss_eval_sources in every pytorch version
Expand Down
33 changes: 9 additions & 24 deletions torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE

__doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]}


class SignalDistortionRatio(Metric):
r"""Signal to Distortion Ratio (SDR) [1,2,3]
r"""Signal to Distortion Ratio (SDR) [1,2]

Forward accepts

Expand All @@ -32,10 +31,13 @@ class SignalDistortionRatio(Metric):

Args:
use_cg_iter:
If provided, an iterative method is used to solve for the distortion filter coefficients instead
of direct Gaussian elimination. This can speed up the computation of the metrics in case the filters
are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient
when using this loss to train neural separation networks.
If provided, conjugate gradient descent is used to solve for the distortion
filter coefficients instead of direct Gaussian elimination, which requires that
``fast-bss-eval`` is installed and pytorch version >= 1.8.
This can speed up the computation of the metrics in case the filters
are long. Using a value of 10 here has been shown to provide
good accuracy in most cases and is sufficient when using this
loss to train neural separation networks.
filter_length: The length of the distortion filter allowed
zero_mean:
When set to True, the mean of all signals is subtracted prior to computation of the metrics
Expand All @@ -51,10 +53,6 @@ class SignalDistortionRatio(Metric):

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ModuleNotFoundError:
If ``fast-bss-eval`` package is not installed

Example:
>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
Expand All @@ -74,22 +72,14 @@ class SignalDistortionRatio(Metric):
tensor(-11.6051)

.. note::
1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be
non-differentiable and slower to calculate

2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install
torchmetrics[audio]`` or ``pip install fast-bss-eval``

3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype
Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype


References:
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation.
IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.

[2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations.

[3] https://github.com/fakufaku/fast_bss_eval
"""

sum_sdr: Tensor
Expand All @@ -106,11 +96,6 @@ def __init__(
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
if not _FAST_BSS_EVAL_AVAILABLE:
raise ModuleNotFoundError(
"SDR metric requires that `fast-bss-eval` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`."
)
super().__init__(compute_on_step=compute_on_step, **kwargs)

self.use_cg_iter = use_cg_iter
Expand Down
193 changes: 115 additions & 78 deletions torchmetrics/functional/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8

if _FAST_BSS_EVAL_AVAILABLE:
if _TORCH_GREATER_EQUAL_1_8:
from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient
from fast_bss_eval.torch.helpers import _normalize
from fast_bss_eval.torch.linalg import toeplitz
from fast_bss_eval.torch.metrics import compute_stats
# import or def the norm function
if _TORCH_GREATER_EQUAL_1_7:
from torch.linalg import norm
else:
from torch import norm

solve = torch.linalg.solve
else:
import numpy
from fast_bss_eval.numpy.cgd import toeplitz_conjugate_gradient
from fast_bss_eval.numpy.helpers import _normalize
from fast_bss_eval.numpy.linalg import toeplitz
from fast_bss_eval.numpy.metrics import compute_stats
# import or redirect the solve function
if _TORCH_GREATER_EQUAL_1_8:
solve = torch.linalg.solve
else:
from torch import solve as _solve

solve = numpy.linalg.solve
def solve(A: Tensor, b: Tensor) -> Tensor:
Borda marked this conversation as resolved.
Show resolved Hide resolved
return _solve(b, A)[0]


if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient
else:
toeplitz = None
toeplitz_conjugate_gradient = None
compute_stats = None
_normalize = None
__doctest_skip__ = ["signal_distortion_ratio"]

from torch import Tensor

from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
def _symmetric_toeplitz(vector: Tensor) -> Tensor:
"""Construct a symmetric Toeplitz matrix using one vector.

Args:
vector: shape [..., L]

Example:
>>> from torchmetrics.functional.audio.sdr import _symmetric_toeplitz
>>> import torch
>>> v = torch.tensor([0, 1, 2, 3, 4])
>>> _symmetric_toeplitz(v)
tensor([[0, 1, 2, 3, 4],
[1, 0, 1, 2, 3],
[2, 1, 0, 1, 2],
[3, 2, 1, 0, 1],
[4, 3, 2, 1, 0]])

Returns:
a symmetric Toeplitz matrix of shape [..., L, L]
"""
vec_exp = torch.cat([torch.flip(vector, dims=(-1,)), vector[..., 1:]], dim=-1)
v_len = vector.shape[-1]
return torch.as_strided(
vec_exp, size=vec_exp.shape[:-1] + (v_len, v_len), stride=vec_exp.stride()[:-1] + (1, 1)
).flip(dims=(-1,))


def _compute_autocorr_crosscorr(
target: torch.Tensor,
preds: torch.Tensor,
L: int,
Borda marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using the fast
Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as
`R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of
`preds` in the column space of the L shifts of `target`.

Args:
target: the target (reference) signal of shape [..., time]
preds: the preds (estimated) signal of shape [..., time]
L: the length of the auto correlation and cross correlation

Returns:
the auto correlation of `target` of shape [..., L]
the cross correlation of `target` and `preds` of shape [..., L]
"""
# the valid length for the signal after convolution
n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1))

# computes the auto correlation of `target`
T = torch.fft.rfft(target, n=n_fft, dim=-1)
# R_0 is the first row of the symmetric Toeplitz matric
R_0 = torch.fft.irfft(T.real**2 + T.imag**2, n=n_fft)[..., :L]

# computes the cross-correlation of `target` and `preds`
P = torch.fft.rfft(preds, n=n_fft, dim=-1)
TP = T.conj() * P
b = torch.fft.irfft(TP, n=n_fft, dim=-1)[..., :L]
return R_0, b
Borda marked this conversation as resolved.
Show resolved Hide resolved


def signal_distortion_ratio(
Expand All @@ -56,33 +113,30 @@ def signal_distortion_ratio(
zero_mean: bool = False,
load_diag: Optional[float] = None,
) -> Tensor:
r"""Signal to Distortion Ratio (SDR) [1,2,3]
r"""Signal to Distortion Ratio (SDR) [1,2]

Args:
preds: shape ``[..., time]``
target: shape ``[..., time]``
use_cg_iter:
If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct
Gaussian elimination.
This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here
has been shown to provide good accuracy in most cases and is sufficient when using this loss to train
neural separation networks.
If provided, conjugate gradient descent is used to solve for the distortion
filter coefficients instead of direct Gaussian elimination, which requires that
``fast-bss-eval`` is installed and pytorch version >= 1.8.
This can speed up the computation of the metrics in case the filters
are long. Using a value of 10 here has been shown to provide
good accuracy in most cases and is sufficient when using this
loss to train neural separation networks.
filter_length: The length of the distortion filter allowed
zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics
load_diag:
If provided, this small value is added to the diagonal coefficients of
the system metrics when solving for the filter coefficients.
This can help stabilize the metric in the case where some reference signals may sometimes be zero

Raises:
ModuleNotFoundError:
If ``fast-bss-eval`` package is not installed

Returns:
sdr value of shape ``[...]``

Example:

>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
Expand All @@ -104,28 +158,15 @@ def signal_distortion_ratio(
[0, 1]])

.. note::
1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be
non-differentiable and slower to calculate

2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install
torchmetrics[audio]`` or ``pip install fast-bss-eval``

3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype
Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype


References:
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation.
IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.

[2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations.

[3] https://github.com/fakufaku/fast_bss_eval
"""
if not _FAST_BSS_EVAL_AVAILABLE:
raise ModuleNotFoundError(
"SDR metric requires that `fast-bss-eval` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`."
)
_check_same_shape(preds, target)

if not preds.dtype.is_floating_point:
Expand All @@ -142,47 +183,43 @@ def signal_distortion_ratio(
preds = preds - preds.mean(dim=-1, keepdim=True)
target = target - target.mean(dim=-1, keepdim=True)

# normalize along time-axis
if not _TORCH_GREATER_EQUAL_1_8:
# use numpy if torch<1.8
rank_zero_warn(
"Pytorch is under 1.8, thus SDR numpy version is used."
"For better performance and differentiability, you should change to Pytorch v1.8 or above."
)
device = preds.device
preds = preds.detach().cpu().numpy()
target = target.detach().cpu().numpy()

preds = _normalize(preds, axis=-1)
target = _normalize(target, axis=-1)
else:
preds = _normalize(preds, dim=-1)
target = _normalize(target, dim=-1)
# normalize along time-axis to make preds and target have unit norm
target = target / torch.clamp(norm(target, axis=-1, keepdims=True), min=1e-6)
preds = preds / torch.clamp(norm(preds, axis=-1, keepdims=True), min=1e-6)

# solve for the optimal filter
# compute auto-correlation and cross-correlation
acf, xcorr = compute_stats(target, preds, length=filter_length, pairwise=False)
R_0, b = _compute_autocorr_crosscorr(target, preds, L=filter_length)

if load_diag is not None:
# the diagonal factor of the Toeplitz matrix is the first
# coefficient of the acf
acf[..., 0] += load_diag
# the diagonal factor of the Toeplitz matrix is the first coefficient of R_0
R_0[..., 0] += load_diag

if use_cg_iter is not None:
if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
# use preconditioned conjugate gradient
sol = toeplitz_conjugate_gradient(acf, xcorr, n_iter=use_cg_iter)
sol = toeplitz_conjugate_gradient(R_0, b, n_iter=use_cg_iter)
else:
if use_cg_iter is not None:
if not _FAST_BSS_EVAL_AVAILABLE:
warnings.warn(
"The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. "
"To make this this warning disappear, you could install `fast-bss-eval` using `pip install fast-bss-eval` "
"or set `use_cg_iter=None`. For this time, the solver provided by Pytorch is used.",
UserWarning,
)
elif not _TORCH_GREATER_EQUAL_1_8:
warnings.warn(
"The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. "
"To make this this warning disappear, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. "
"For this time, the solver provided by Pytorch is used.",
UserWarning,
)
# regular matrix solver
r_mat = toeplitz(acf)
sol = solve(r_mat, xcorr)

# to tensor if torch<1.8
if not _TORCH_GREATER_EQUAL_1_8:
sol = torch.tensor(sol, device=device)
xcorr = torch.tensor(xcorr, device=device)
R = _symmetric_toeplitz(R_0) # the auto-correlation of the L shifts of `target`
sol = solve(R, b)

# compute the coherence
coh = torch.einsum("...l,...l->...", xcorr, sol)
coh = torch.einsum("...l,...l->...", b, sol)

# transform to decibels
ratio = coh / (1 - coh)
Expand Down