diff --git a/CHANGELOG.md b/CHANGELOG.md index eab50a8ee82..0ab7a66e997 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Reimplemented the `signal_distortion_ratio` metric, which removed the absolute requirement of `fast-bss-eval` ([#964](https://github.com/PyTorchLightning/metrics/pull/964)) - diff --git a/requirements/audio.txt b/requirements/audio.txt index a594d397e84..f0d64dfb017 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,4 +1,2 @@ pesq>=0.0.3 pystoi -fast-bss-eval>=0.1.0 -torch_complex # needed for fast-bss-eval torch<=1.7 diff --git a/requirements/audio_test.txt b/requirements/audio_test.txt index 8b9fa30e394..9c37ba7cf40 100644 --- a/requirements/audio_test.txt +++ b/requirements/audio_test.txt @@ -1,3 +1,5 @@ pypesq mir_eval>=0.6 speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip +fast-bss-eval>=0.1.0 +torch_complex # needed for fast-bss-eval torch<=1.7 diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 7361c596a13..15e31d80664 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -27,7 +27,7 @@ from tests.helpers.testers import MetricTester from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12_DEV +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42) @@ -99,7 +99,6 @@ def test_sdr_functional(self, preds, target, sk_metric): metric_args=dict(), ) - @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="sdr is not differentiable for pytorch < 1.8") def test_sdr_differentiability(self, preds, target, sk_metric): self.run_differentiability_test( preds=preds, @@ -155,14 +154,7 @@ def test_too_low_precision(): preds = torch.tensor(data["preds"]) target 
= torch.tensor(data["target"]) - if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12_DEV: - with pytest.warns( - UserWarning, - match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision", - ): - sdr_tm = signal_distortion_ratio(preds, target) - else: # when pytorch < 1.8 or pytorch >= 1.12, sdr doesn't have this problem - sdr_tm = signal_distortion_ratio(preds, target).double() + sdr_tm = signal_distortion_ratio(preds, target).double() # check equality with bss_eval_sources in every pytorch version sdr_bss, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False) diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index 8f03cdb0336..ee5fb85c21b 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -17,13 +17,12 @@ from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE __doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]} class SignalDistortionRatio(Metric): - r"""Signal to Distortion Ratio (SDR) [1,2,3] + r"""Signal to Distortion Ratio (SDR) [1,2] Forward accepts @@ -32,10 +31,13 @@ class SignalDistortionRatio(Metric): Args: use_cg_iter: - If provided, an iterative method is used to solve for the distortion filter coefficients instead - of direct Gaussian elimination. This can speed up the computation of the metrics in case the filters - are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient - when using this loss to train neural separation networks. + If provided, conjugate gradient descent is used to solve for the distortion + filter coefficients instead of direct Gaussian elimination, which requires that + ``fast-bss-eval`` is installed and pytorch version >= 1.8. + This can speed up the computation of the metrics in case the filters + are long. 
Using a value of 10 here has been shown to provide + good accuracy in most cases and is sufficient when using this + loss to train neural separation networks. filter_length: The length of the distortion filter allowed zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics @@ -51,10 +53,6 @@ class SignalDistortionRatio(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Raises: - ModuleNotFoundError: - If ``fast-bss-eval`` package is not installed - Example: >>> from torchmetrics.audio import SignalDistortionRatio >>> import torch @@ -73,23 +71,11 @@ class SignalDistortionRatio(Metric): >>> pit(preds, target) tensor(-11.6051) - .. note:: - 1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be - non-differentiable and slower to calculate - - 2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install - torchmetrics[audio]`` or ``pip install fast-bss-eval`` - - 3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype - - References: [1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. - - [3] https://github.com/fakufaku/fast_bss_eval """ sum_sdr: Tensor @@ -106,11 +92,6 @@ def __init__( compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any], ) -> None: - if not _FAST_BSS_EVAL_AVAILABLE: - raise ModuleNotFoundError( - "SDR metric requires that `fast-bss-eval` is installed." - " Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`." 
- ) super().__init__(compute_on_step=compute_on_step, **kwargs) self.use_cg_iter = use_cg_iter diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 366f29e10fc..8a3b9490a9e 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -12,40 +12,111 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings -from typing import Optional +from typing import Optional, Tuple import torch +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 -if _FAST_BSS_EVAL_AVAILABLE: - if _TORCH_GREATER_EQUAL_1_8: - from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient - from fast_bss_eval.torch.helpers import _normalize - from fast_bss_eval.torch.linalg import toeplitz - from fast_bss_eval.torch.metrics import compute_stats +# import or def the norm/solve function +if _TORCH_GREATER_EQUAL_1_8: + from torch.linalg import norm - solve = torch.linalg.solve - else: - import numpy - from fast_bss_eval.numpy.cgd import toeplitz_conjugate_gradient - from fast_bss_eval.numpy.helpers import _normalize - from fast_bss_eval.numpy.linalg import toeplitz - from fast_bss_eval.numpy.metrics import compute_stats + solve = torch.linalg.solve +else: + from torch import norm + from torch import solve as _solve + from torch.nn.functional import pad + + def solve(A: Tensor, b: Tensor) -> Tensor: + return _solve(b[..., None], A)[0][..., 0] - solve = numpy.linalg.solve + +if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: + from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient else: - toeplitz = None toeplitz_conjugate_gradient = None - compute_stats = None - _normalize = None - __doctest_skip__ = ["signal_distortion_ratio"] -from torch import Tensor -from torchmetrics.utilities import rank_zero_warn 
-from torchmetrics.utilities.checks import _check_same_shape +def _symmetric_toeplitz(vector: Tensor) -> Tensor: +    """Construct a symmetric Toeplitz matrix using one vector. + +    Args: +        vector: shape [..., L] + +    Example: +        >>> from torchmetrics.functional.audio.sdr import _symmetric_toeplitz +        >>> import torch +        >>> v = torch.tensor([0, 1, 2, 3, 4]) +        >>> _symmetric_toeplitz(v) +        tensor([[0, 1, 2, 3, 4], +                [1, 0, 1, 2, 3], +                [2, 1, 0, 1, 2], +                [3, 2, 1, 0, 1], +                [4, 3, 2, 1, 0]]) + +    Returns: +        a symmetric Toeplitz matrix of shape [..., L, L] +    """ +    vec_exp = torch.cat([torch.flip(vector, dims=(-1,)), vector[..., 1:]], dim=-1) +    v_len = vector.shape[-1] +    return torch.as_strided( +        vec_exp, size=vec_exp.shape[:-1] + (v_len, v_len), stride=vec_exp.stride()[:-1] + (1, 1) +    ).flip(dims=(-1,)) + + +def _compute_autocorr_crosscorr( +    target: torch.Tensor, +    preds: torch.Tensor, +    corr_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: +    r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using the fast +    Fourier transform (FFT). Let us denote the symmetric Toeplitz matrix of the auto correlation of `target` as +    `R`, the cross correlation as `b`, then solving the equation `Rh=b` could have `h` as the coordinate of +    `preds` in the column space of the `corr_len` shifts of `target`. 
+ + Args: + target: the target (reference) signal of shape [..., time] + preds: the preds (estimated) signal of shape [..., time] + corr_len: the length of the auto correlation and cross correlation + + Returns: + the auto correlation of `target` of shape [..., corr_len] + the cross correlation of `target` and `preds` of shape [..., corr_len] + """ + # the valid length for the signal after convolution + n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) + + # computes the auto correlation of `target` + # r_0 is the first row of the symmetric Toeplitz matric + if _TORCH_GREATER_EQUAL_1_8: + t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) + r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] + else: + t_pad = pad(target, (0, n_fft - target.shape[-1]), "constant", 0) + t_fft = torch.rfft(t_pad, signal_ndim=1) + real = t_fft[..., 0] ** 2 + t_fft[..., 1] ** 2 + imag = torch.zeros(real.shape, dtype=real.dtype, device=real.device) + result = torch.stack([real, imag], len(real.shape)) + r_0 = torch.irfft(result, signal_ndim=1, signal_sizes=[n_fft])[..., :corr_len] + + # computes the cross-correlation of `target` and `preds` + if _TORCH_GREATER_EQUAL_1_8: + p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) + b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] + else: + p_pad = pad(preds, (0, n_fft - preds.shape[-1]), "constant", 0) + p_fft = torch.rfft(p_pad, signal_ndim=1) + real = t_fft[..., 0] * p_fft[..., 0] + t_fft[..., 1] * p_fft[..., 1] + imag = t_fft[..., 0] * p_fft[..., 1] - t_fft[..., 1] * p_fft[..., 0] + result = torch.stack([real, imag], len(real.shape)) + b = torch.irfft(result, signal_ndim=1, signal_sizes=[n_fft])[..., :corr_len] + + return r_0, b def signal_distortion_ratio( @@ -56,17 +127,19 @@ def signal_distortion_ratio( zero_mean: bool = False, load_diag: Optional[float] = None, ) -> Tensor: - r"""Signal to Distortion Ratio (SDR) [1,2,3] + r"""Signal to Distortion Ratio (SDR) [1,2] Args: 
preds: shape ``[..., time]`` target: shape ``[..., time]`` use_cg_iter: - If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct - Gaussian elimination. - This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here - has been shown to provide good accuracy in most cases and is sufficient when using this loss to train - neural separation networks. + If provided, conjugate gradient descent is used to solve for the distortion + filter coefficients instead of direct Gaussian elimination, which requires that + ``fast-bss-eval`` is installed and pytorch version >= 1.8. + This can speed up the computation of the metrics in case the filters + are long. Using a value of 10 here has been shown to provide + good accuracy in most cases and is sufficient when using this + loss to train neural separation networks. filter_length: The length of the distortion filter allowed zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics load_diag: @@ -74,15 +147,10 @@ def signal_distortion_ratio( the system metrics when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero - Raises: - ModuleNotFoundError: - If ``fast-bss-eval`` package is not installed - Returns: sdr value of shape ``[...]`` Example: - >>> from torchmetrics.functional.audio import signal_distortion_ratio >>> import torch >>> g = torch.manual_seed(1) @@ -103,100 +171,70 @@ def signal_distortion_ratio( [1, 0], [0, 1]]) - .. note:: - 1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be - non-differentiable and slower to calculate - - 2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install - torchmetrics[audio]`` or ``pip install fast-bss-eval`` - - 3. 
preds and target need to have the same dtype, otherwise target will be converted to preds' dtype - - References: [1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. - - [3] https://github.com/fakufaku/fast_bss_eval """ - if not _FAST_BSS_EVAL_AVAILABLE: - raise ModuleNotFoundError( - "SDR metric requires that `fast-bss-eval` is installed." - " Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`." - ) _check_same_shape(preds, target) - if not preds.dtype.is_floating_point: - preds = preds.float() # for torch.norm - - # half precision support - if preds.dtype == torch.float16: - preds = preds.to(torch.float32) - - if preds.dtype != target.dtype: # for torch.linalg.solve - target = target.to(preds.dtype) + # use double precision + preds_dtype = preds.dtype + preds = preds.double() + target = target.double() if zero_mean: preds = preds - preds.mean(dim=-1, keepdim=True) target = target - target.mean(dim=-1, keepdim=True) - # normalize along time-axis - if not _TORCH_GREATER_EQUAL_1_8: - # use numpy if torch<1.8 - rank_zero_warn( - "Pytorch is under 1.8, thus SDR numpy version is used." - "For better performance and differentiability, you should change to Pytorch v1.8 or above." 
-        ) -        device = preds.device -        preds = preds.detach().cpu().numpy() -        target = target.detach().cpu().numpy() - -        preds = _normalize(preds, axis=-1) -        target = _normalize(target, axis=-1) -    else: -        preds = _normalize(preds, dim=-1) -        target = _normalize(target, dim=-1) +    # normalize along time-axis to make preds and target have unit norm +    target = target / torch.clamp(norm(target, dim=-1, keepdim=True), min=1e-6) +    preds = preds / torch.clamp(norm(preds, dim=-1, keepdim=True), min=1e-6)      # solve for the optimal filter      # compute auto-correlation and cross-correlation -    acf, xcorr = compute_stats(target, preds, length=filter_length, pairwise=False) +    r_0, b = _compute_autocorr_crosscorr(target, preds, corr_len=filter_length)      if load_diag is not None: -        # the diagonal factor of the Toeplitz matrix is the first -        # coefficient of the acf -        acf[..., 0] += load_diag +        # the diagonal factor of the Toeplitz matrix is the first coefficient of r_0 +        r_0[..., 0] += load_diag  -    if use_cg_iter is not None: +    if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:          # use preconditioned conjugate gradient -        sol = toeplitz_conjugate_gradient(acf, xcorr, n_iter=use_cg_iter) +        sol = toeplitz_conjugate_gradient(r_0, b, n_iter=use_cg_iter)      else: +        if use_cg_iter is not None: +            if not _FAST_BSS_EVAL_AVAILABLE: +                warnings.warn( +                    "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. " +                    "To make this warning disappear, you could install `fast-bss-eval` using " +                    "`pip install fast-bss-eval` or set `use_cg_iter=None`. For this time, the solver " +                    "provided by Pytorch is used.", +                    UserWarning, +                ) +            elif not _TORCH_GREATER_EQUAL_1_8: +                warnings.warn( +                    "The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. " +                    "To make this warning disappear, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. 
" + "For this time, the solver provided by Pytorch is used.", + UserWarning, + ) # regular matrix solver - r_mat = toeplitz(acf) - sol = solve(r_mat, xcorr) - - # to tensor if torch<1.8 - if not _TORCH_GREATER_EQUAL_1_8: - sol = torch.tensor(sol, device=device) - xcorr = torch.tensor(xcorr, device=device) + r = _symmetric_toeplitz(r_0) # the auto-correlation of the L shifts of `target` + sol = solve(r, b) # compute the coherence - coh = torch.einsum("...l,...l->...", xcorr, sol) + coh = torch.einsum("...l,...l->...", b, sol) # transform to decibels ratio = coh / (1 - coh) val = 10.0 * torch.log10(ratio) - # recompute sdr in float64 if val is NaN or Inf - if (torch.isnan(val).any() or torch.isinf(val).any()) and preds.dtype != torch.float64: - warnings.warn( - "Detected `nan` or `inf` value in computed metric, retrying computation in double precision", - UserWarning, - ) - val = signal_distortion_ratio(preds.double(), target.double(), use_cg_iter, filter_length, zero_mean, load_diag) - - return val + if preds_dtype == torch.float64: + return val + else: + return val.float() def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: