From eedeeccf1ecb893e073de77a9ae1468c78fbbe99 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 23 Mar 2022 04:04:54 +0000 Subject: [PATCH 01/30] fix sdr in 1.12 --- tests/audio/test_sdr.py | 17 ++++++++++------- torchmetrics/utilities/imports.py | 1 + 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index a03d67cafd4..c37b7fd4638 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -27,7 +27,7 @@ from tests.helpers.testers import MetricTester from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12 seed_all(42) @@ -149,18 +149,21 @@ def test_on_real_audio(): ) -@pytest.mark.skipif( - not _TORCH_GREATER_EQUAL_1_8, reason="when pytorch < 1.8, sdr is using numpy which doesn't have this problem" -) def test_too_low_precision(): """Corner case where the precision of the input is important.""" data = np.load(_SAMPLE_NUMPY_ISSUE_895) preds = torch.tensor(data["preds"]) target = torch.tensor(data["target"]) - with pytest.warns( - UserWarning, match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision" - ): + + if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12: + with pytest.warns( + UserWarning, match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision" + ): + sdr_tm = signal_distortion_ratio(preds, target) + else: # when pytorch < 1.8 or pytorch > 1.12, sdr doesn't have this problem sdr_tm = signal_distortion_ratio(preds, target) + + # check equality with bss_eval_sources in every pytorch version sdr_bss, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False) assert torch.allclose( sdr_tm.mean(), diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 3f7c2f7a8fc..0f09b687b3a 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -94,6 +94,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _TORCH_LOWER_1_4: Optional[bool] = _compare_version("torch", operator.lt, "1.4.0") _TORCH_LOWER_1_5: Optional[bool] = _compare_version("torch", operator.lt, "1.5.0") _TORCH_LOWER_1_6: Optional[bool] = _compare_version("torch", operator.lt, "1.6.0") +_TORCH_LOWER_1_12: Optional[bool] = _compare_version("torch", operator.lt, "1.12.0") _TORCH_GREATER_EQUAL_1_6: Optional[bool] = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7: Optional[bool] = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0") From f40f3323060c89d147309a47af755a9c0074dab9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Mar 2022 04:28:11 +0000 Subject: [PATCH 02/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/audio/test_sdr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index c37b7fd4638..ec82af0a90f 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -157,7 +157,8 @@ def test_too_low_precision(): if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12: with pytest.warns( - UserWarning, match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision" + UserWarning, + match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision", ): sdr_tm = signal_distortion_ratio(preds, target) else: # when pytorch < 1.8 or pytorch > 1.12, sdr doesn't have this problem From 2f5b59ee597d4cec47dbd1d7b6f13d7262d1d78b Mon Sep 17 00:00:00 2001 From: Changsheng Quan Date: Wed, 23 Mar 2022 13:16:54 +0800 Subject: [PATCH 03/30] Update tests/audio/test_sdr.py --- tests/audio/test_sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index ec82af0a90f..7bb39c7ea00 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -162,7 +162,7 @@ def test_too_low_precision(): ): sdr_tm = signal_distortion_ratio(preds, target) else: # when pytorch < 1.8 or pytorch > 1.12, sdr doesn't have this problem - sdr_tm = signal_distortion_ratio(preds, target) + sdr_tm = signal_distortion_ratio(preds, target).double() # check equality with bss_eval_sources in every pytorch version sdr_bss, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False) From a7a0d0b380593222d4b34e1572e8228b2d9e6580 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 23 Mar 2022 07:07:09 +0000 Subject: [PATCH 04/30] update --- tests/audio/test_sdr.py | 8 ++++---- torchmetrics/utilities/imports.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 7bb39c7ea00..7361c596a13 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -27,7 +27,7 @@ from tests.helpers.testers import MetricTester from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12_DEV seed_all(42) @@ -155,13 +155,13 @@ def test_too_low_precision(): preds = torch.tensor(data["preds"]) target = torch.tensor(data["target"]) - if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12: + if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12_DEV: with pytest.warns( UserWarning, match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision", ): sdr_tm = signal_distortion_ratio(preds, target) - else: # when pytorch < 1.8 or pytorch > 1.12, sdr doesn't have this problem + else: # when pytorch < 1.8 or pytorch >= 1.12, sdr doesn't have this problem sdr_tm = signal_distortion_ratio(preds, target).double() # check equality with bss_eval_sources in every pytorch version @@ -170,5 +170,5 @@ def test_too_low_precision(): sdr_tm.mean(), torch.tensor(sdr_bss).mean(), rtol=0.0001, - atol=1e-4, + atol=1e-2, ) diff --git a/torchmetrics/utilities/imports.py b/torchmetrics/utilities/imports.py index 0f09b687b3a..9b0f7e41766 100644 --- a/torchmetrics/utilities/imports.py +++ b/torchmetrics/utilities/imports.py @@ -94,7 +94,7 @@ def _compare_version(package: str, op: Callable, version: str) -> Optional[bool] _TORCH_LOWER_1_4: Optional[bool] = _compare_version("torch", operator.lt, "1.4.0") _TORCH_LOWER_1_5: Optional[bool] = _compare_version("torch", operator.lt, "1.5.0") _TORCH_LOWER_1_6: Optional[bool] = _compare_version("torch", operator.lt, "1.6.0") -_TORCH_LOWER_1_12: Optional[bool] = _compare_version("torch", operator.lt, "1.12.0") +_TORCH_LOWER_1_12_DEV: Optional[bool] = _compare_version("torch", operator.lt, "1.12.0.dev") _TORCH_GREATER_EQUAL_1_6: Optional[bool] = _compare_version("torch", operator.ge, "1.6.0") _TORCH_GREATER_EQUAL_1_7: Optional[bool] = _compare_version("torch", operator.ge, "1.7.0") _TORCH_GREATER_EQUAL_1_8: Optional[bool] = _compare_version("torch", operator.ge, "1.8.0") From 2899ddc2efde5db485e3028c10ee6c7500fcb55a Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 17 Apr 2022 12:54:56 +0000 Subject: [PATCH 05/30] reimplement signal_distortion_ratio --- CHANGELOG.md | 2 +- requirements/audio.txt | 2 - requirements/audio_test.txt | 2 + torchmetrics/audio/sdr.py | 27 ++-- torchmetrics/functional/audio/sdr.py | 192 ++++++++++++++++----------- 5 files changed, 126 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9818244755..1d0a73d33d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Reimplemented the `signal_distortion_ratio` metric, which removed the absolute requirement of `fast-bss-eval`. - diff --git a/requirements/audio.txt b/requirements/audio.txt index a594d397e84..f0d64dfb017 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,4 +1,2 @@ pesq>=0.0.3 pystoi -fast-bss-eval>=0.1.0 -torch_complex # needed for fast-bss-eval torch<=1.7 diff --git a/requirements/audio_test.txt b/requirements/audio_test.txt index 8b9fa30e394..9c37ba7cf40 100644 --- a/requirements/audio_test.txt +++ b/requirements/audio_test.txt @@ -1,3 +1,5 @@ pypesq mir_eval>=0.6 speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/master.zip +fast-bss-eval>=0.1.0 +torch_complex # needed for fast-bss-eval torch<=1.7 diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index 8f03cdb0336..c293e86ab38 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -23,7 +23,7 @@ class SignalDistortionRatio(Metric): - r"""Signal to Distortion Ratio (SDR) [1,2,3] + r"""Signal to Distortion Ratio (SDR) [1,2] Forward accepts @@ -32,10 +32,13 @@ class SignalDistortionRatio(Metric): Args: use_cg_iter: - If provided, an iterative method is used to solve for the distortion filter coefficients instead - of direct Gaussian elimination. This can speed up the computation of the metrics in case the filters - are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient - when using this loss to train neural separation networks. + If provided, conjugate gradient descent is used to solve for the distortion + filter coefficients instead of direct Gaussian elimination, which requires that + ``fast-bss-eval`` is installed and pytorch version >= 1.8. + This can speed up the computation of the metrics in case the filters + are long. Using a value of 10 here has been shown to provide + good accuracy in most cases and is sufficient when using this + loss to train neural separation networks. filter_length: The length of the distortion filter allowed zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics @@ -51,10 +54,6 @@ class SignalDistortionRatio(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Raises: - ModuleNotFoundError: - If ``fast-bss-eval`` package is not installed - Example: >>> from torchmetrics.audio import SignalDistortionRatio >>> import torch @@ -74,13 +73,7 @@ class SignalDistortionRatio(Metric): tensor(-11.6051) .. note:: - 1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be - non-differentiable and slower to calculate - - 2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install - torchmetrics[audio]`` or ``pip install fast-bss-eval`` - - 3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + 1. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype References: @@ -88,8 +81,6 @@ class SignalDistortionRatio(Metric): IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. - - [3] https://github.com/fakufaku/fast_bss_eval """ sum_sdr: Tensor diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 366f29e10fc..2bf8bfdd05f 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -12,40 +12,96 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings -from typing import Optional +from typing import Optional, Tuple import torch +from torch import Tensor -from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 -if _FAST_BSS_EVAL_AVAILABLE: - if _TORCH_GREATER_EQUAL_1_8: - from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient - from fast_bss_eval.torch.helpers import _normalize - from fast_bss_eval.torch.linalg import toeplitz - from fast_bss_eval.torch.metrics import compute_stats +# import or def the norm function +if _TORCH_GREATER_EQUAL_1_7: + from torch.linalg import norm +else: + from torch import norm - solve = torch.linalg.solve - else: - import numpy - from fast_bss_eval.numpy.cgd import toeplitz_conjugate_gradient - from fast_bss_eval.numpy.helpers import _normalize - from fast_bss_eval.numpy.linalg import toeplitz - from fast_bss_eval.numpy.metrics import compute_stats +# import or redirect the solve function +if _TORCH_GREATER_EQUAL_1_8: + solve = torch.linalg.solve +else: + from torch import solve as _solve - solve = numpy.linalg.solve + def solve(A: Tensor, b: Tensor) -> Tensor: + return _solve(b, A)[0] + + +if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: + from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient else: - toeplitz = None toeplitz_conjugate_gradient = None - compute_stats = None - _normalize = None - __doctest_skip__ = ["signal_distortion_ratio"] -from torch import Tensor -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.checks import _check_same_shape +def _symmetric_toeplitz(v: Tensor) -> Tensor: + """Construct a symmetric Toeplitz matrix using v + + Args: + v: shape [..., L] + + Example: + >>> from torchmetrics.functional.audio.sdr import _symmetric_toeplitz + >>> import torch + >>> v = torch.tensor([0, 1, 2, 3, 4]) + >>> _symmetric_toeplitz(v) + tensor([[0, 1, 2, 3, 4], + [1, 0, 1, 2, 3], + [2, 1, 0, 1, 2], + [3, 2, 1, 0, 1], + [4, 3, 2, 1, 0]]) + + Returns: + a symmetric Toeplitz matrix of shape [..., L, L] + """ + vals = torch.cat([torch.flip(v, dims=(-1,)), v[..., 1:]], dim=-1) + L = v.shape[-1] + return torch.as_strided(vals, size=vals.shape[:-1] + (L, L), stride=vals.stride()[:-1] + (1, 1)).flip(dims=(-1,)) + + +def _compute_autocorr_crosscorr( + target: torch.Tensor, + preds: torch.Tensor, + L: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ + the fast Fourier transform (FFT). \ + Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as `R`, the cross correlation \ + as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of `preds` in the column space of \ + the L shifts of `target`. + + Args: + target: the target (reference) signal of shape [..., time] + preds: the preds (estimated) signal of shape [..., time] + L: the length of the auto correlation and cross correlation + + Returns: + the auto correlation of `target` of shape [..., L] + the cross correlation of `target` and `preds` of shape [..., L] + """ + # the valid length for the signal after convolution + n_fft = 2**math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) + + # computes the auto correlation of `target` + T = torch.fft.rfft(target, n=n_fft, dim=-1) + # R_0 is the first row of the symmetric Toeplitz matric + R_0 = torch.fft.irfft(T.real**2 + T.imag**2, n=n_fft)[..., :L] + + # computes the cross-correlation of `target` and `preds` + P = torch.fft.rfft(preds, n=n_fft, dim=-1) + TP = T.conj() * P + b = torch.fft.irfft(TP, n=n_fft, dim=-1)[..., :L] + return R_0, b def signal_distortion_ratio( @@ -56,17 +112,19 @@ def signal_distortion_ratio( zero_mean: bool = False, load_diag: Optional[float] = None, ) -> Tensor: - r"""Signal to Distortion Ratio (SDR) [1,2,3] + r"""Signal to Distortion Ratio (SDR) [1,2] Args: preds: shape ``[..., time]`` target: shape ``[..., time]`` use_cg_iter: - If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct - Gaussian elimination. - This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here - has been shown to provide good accuracy in most cases and is sufficient when using this loss to train - neural separation networks. + If provided, conjugate gradient descent is used to solve for the distortion + filter coefficients instead of direct Gaussian elimination, which requires that + ``fast-bss-eval`` is installed and pytorch version >= 1.8. + This can speed up the computation of the metrics in case the filters + are long. Using a value of 10 here has been shown to provide + good accuracy in most cases and is sufficient when using this + loss to train neural separation networks. filter_length: The length of the distortion filter allowed zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics load_diag: @@ -74,15 +132,10 @@ def signal_distortion_ratio( the system metrics when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero - Raises: - ModuleNotFoundError: - If ``fast-bss-eval`` package is not installed - Returns: sdr value of shape ``[...]`` Example: - >>> from torchmetrics.functional.audio import signal_distortion_ratio >>> import torch >>> g = torch.manual_seed(1) @@ -104,13 +157,7 @@ def signal_distortion_ratio( [0, 1]]) .. note:: - 1. when pytorch<1.8.0, numpy will be used to calculate this metric, which causes ``sdr`` to be - non-differentiable and slower to calculate - - 2. using this metrics requires you to have ``fast-bss-eval`` install. Either install as ``pip install - torchmetrics[audio]`` or ``pip install fast-bss-eval`` - - 3. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + 1. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype References: @@ -118,14 +165,7 @@ def signal_distortion_ratio( IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. - - [3] https://github.com/fakufaku/fast_bss_eval """ - if not _FAST_BSS_EVAL_AVAILABLE: - raise ModuleNotFoundError( - "SDR metric requires that `fast-bss-eval` is installed." - " Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`." - ) _check_same_shape(preds, target) if not preds.dtype.is_floating_point: @@ -142,47 +182,43 @@ def signal_distortion_ratio( preds = preds - preds.mean(dim=-1, keepdim=True) target = target - target.mean(dim=-1, keepdim=True) - # normalize along time-axis - if not _TORCH_GREATER_EQUAL_1_8: - # use numpy if torch<1.8 - rank_zero_warn( - "Pytorch is under 1.8, thus SDR numpy version is used." - "For better performance and differentiability, you should change to Pytorch v1.8 or above." - ) - device = preds.device - preds = preds.detach().cpu().numpy() - target = target.detach().cpu().numpy() - - preds = _normalize(preds, axis=-1) - target = _normalize(target, axis=-1) - else: - preds = _normalize(preds, dim=-1) - target = _normalize(target, dim=-1) + # normalize along time-axis to make preds and target have unit norm + target = target / torch.clamp(norm(target, axis=-1, keepdims=True), min=1e-6) + preds = preds / torch.clamp(norm(preds, axis=-1, keepdims=True), min=1e-6) # solve for the optimal filter # compute auto-correlation and cross-correlation - acf, xcorr = compute_stats(target, preds, length=filter_length, pairwise=False) + R_0, b = _compute_autocorr_crosscorr(target, preds, L=filter_length) if load_diag is not None: - # the diagonal factor of the Toeplitz matrix is the first - # coefficient of the acf - acf[..., 0] += load_diag + # the diagonal factor of the Toeplitz matrix is the first coefficient of R_0 + R_0[..., 0] += load_diag - if use_cg_iter is not None: + if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: # use preconditioned conjugate gradient - sol = toeplitz_conjugate_gradient(acf, xcorr, n_iter=use_cg_iter) + sol = toeplitz_conjugate_gradient(R_0, b, n_iter=use_cg_iter) else: + if use_cg_iter is not None: + if _FAST_BSS_EVAL_AVAILABLE is False: + warnings.warn( + "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. " + "To dispear this warning, you could install `fast-bss-eval` using `pip install fast-bss-eval` " + "or set `use_cg_iter=None`. For this time, the solver provided by Pytorch is used.", + UserWarning, + ) + elif _TORCH_GREATER_EQUAL_1_8 is False: + warnings.warn( + "The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. " + "To dispear this warning, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. " + "For this time, the solver provided by Pytorch is used.", + UserWarning, + ) # regular matrix solver - r_mat = toeplitz(acf) - sol = solve(r_mat, xcorr) - - # to tensor if torch<1.8 - if not _TORCH_GREATER_EQUAL_1_8: - sol = torch.tensor(sol, device=device) - xcorr = torch.tensor(xcorr, device=device) + R = _symmetric_toeplitz(R_0) # the auto-correlation of the L shifts of `target` + sol = solve(R, b) # compute the coherence - coh = torch.einsum("...l,...l->...", xcorr, sol) + coh = torch.einsum("...l,...l->...", b, sol) # transform to decibels ratio = coh / (1 - coh) From 39b37a2178fb23fd8c16bf9ae921c3432c2eb7fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Apr 2022 13:00:22 +0000 Subject: [PATCH 06/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 2bf8bfdd05f..3c8ea1db3bd 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -45,7 +45,7 @@ def solve(A: Tensor, b: Tensor) -> Tensor: def _symmetric_toeplitz(v: Tensor) -> Tensor: - """Construct a symmetric Toeplitz matrix using v + """Construct a symmetric Toeplitz matrix using v. Args: v: shape [..., L] @@ -74,11 +74,10 @@ def _compute_autocorr_crosscorr( preds: torch.Tensor, L: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ - the fast Fourier transform (FFT). \ - Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as `R`, the cross correlation \ - as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of `preds` in the column space of \ - the L shifts of `target`. + """Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ the fast + Fourier transform (FFT). \ Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as + `R`, the cross correlation \ as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of + `preds` in the column space of \ the L shifts of `target`. Args: target: the target (reference) signal of shape [..., time] @@ -90,7 +89,7 @@ def _compute_autocorr_crosscorr( the cross correlation of `target` and `preds` of shape [..., L] """ # the valid length for the signal after convolution - n_fft = 2**math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) + n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) # computes the auto correlation of `target` T = torch.fft.rfft(target, n=n_fft, dim=-1) From 1121751a84f5d1f872edd524683bb2211b3200c6 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 17 Apr 2022 13:03:16 +0000 Subject: [PATCH 07/30] sdr is differentiable for all supported pytorch version now --- tests/audio/test_sdr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 7361c596a13..389d941117c 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -99,7 +99,6 @@ def test_sdr_functional(self, preds, target, sk_metric): metric_args=dict(), ) - @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_8, reason="sdr is not differentiable for pytorch < 1.8") def test_sdr_differentiability(self, preds, target, sk_metric): self.run_differentiability_test( preds=preds, From 95dfb8be222e069fe718fc265dc948817514bcc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Apr 2022 13:03:59 +0000 Subject: [PATCH 08/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 3c8ea1db3bd..d57271ff606 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -74,7 +74,7 @@ def _compute_autocorr_crosscorr( preds: torch.Tensor, L: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ the fast + r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ the fast Fourier transform (FFT). \ Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as `R`, the cross correlation \ as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of `preds` in the column space of \ the L shifts of `target`. From 027443f4a000e8fb26d00fe7fce51b01669e3b42 Mon Sep 17 00:00:00 2001 From: quancs Date: Sun, 17 Apr 2022 13:06:20 +0000 Subject: [PATCH 09/30] update --- tests/audio/test_sdr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 389d941117c..441d3751b03 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -27,7 +27,7 @@ from tests.helpers.testers import MetricTester from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_8, _TORCH_LOWER_1_12_DEV +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_LOWER_1_12_DEV seed_all(42) @@ -154,13 +154,13 @@ def test_too_low_precision(): preds = torch.tensor(data["preds"]) target = torch.tensor(data["target"]) - if _TORCH_GREATER_EQUAL_1_8 and _TORCH_LOWER_1_12_DEV: + if _TORCH_LOWER_1_12_DEV: with pytest.warns( UserWarning, match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision", ): sdr_tm = signal_distortion_ratio(preds, target) - else: # when pytorch < 1.8 or pytorch >= 1.12, sdr doesn't have this problem + else: # when pytorch >= 1.12, sdr doesn't have this problem sdr_tm = signal_distortion_ratio(preds, target).double() # check equality with bss_eval_sources in every pytorch version From f6fbcebdd1708b1dab51cd32f26929db480cb56c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 20 Apr 2022 15:27:13 +0200 Subject: [PATCH 10/30] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ccc2dc2659..94ab1f41c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Reimplemented the `signal_distortion_ratio` metric, which removed the absolute requirement of `fast-bss-eval`. +- Reimplemented the `signal_distortion_ratio` metric, which removed the absolute requirement of `fast-bss-eval` ([#964](https://github.com/PyTorchLightning/metrics/pull/964)) - From 833901e2673dc003d36d61eccb56e746463f1411 Mon Sep 17 00:00:00 2001 From: Changsheng Quan Date: Wed, 20 Apr 2022 21:38:18 +0800 Subject: [PATCH 11/30] Apply suggestions from code review Co-authored-by: Jirka Borovec --- torchmetrics/functional/audio/sdr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index d57271ff606..e09728a6625 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -198,14 +198,14 @@ def signal_distortion_ratio( sol = toeplitz_conjugate_gradient(R_0, b, n_iter=use_cg_iter) else: if use_cg_iter is not None: - if _FAST_BSS_EVAL_AVAILABLE is False: + if not _FAST_BSS_EVAL_AVAILABLE: warnings.warn( "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. " "To dispear this warning, you could install `fast-bss-eval` using `pip install fast-bss-eval` " "or set `use_cg_iter=None`. For this time, the solver provided by Pytorch is used.", UserWarning, ) - elif _TORCH_GREATER_EQUAL_1_8 is False: + elif not _TORCH_GREATER_EQUAL_1_8: warnings.warn( "The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. " "To dispear this warning, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. " From 40a846e9d8244de85b570d868a4cc746a6c7c8b4 Mon Sep 17 00:00:00 2001 From: Changsheng Quan Date: Wed, 20 Apr 2022 21:54:08 +0800 Subject: [PATCH 12/30] Apply suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- torchmetrics/audio/sdr.py | 2 +- torchmetrics/functional/audio/sdr.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index c293e86ab38..ecf2da899d5 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -73,7 +73,7 @@ class SignalDistortionRatio(Metric): tensor(-11.6051) .. note:: - 1. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype References: diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index e09728a6625..7835deecd3e 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -74,10 +74,10 @@ def _compute_autocorr_crosscorr( preds: torch.Tensor, L: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using \ the fast - Fourier transform (FFT). \ Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as - `R`, the cross correlation \ as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of - `preds` in the column space of \ the L shifts of `target`. + r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using the fast + Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as + `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of + `preds` in the column space of the L shifts of `target`. Args: target: the target (reference) signal of shape [..., time] @@ -156,7 +156,7 @@ def signal_distortion_ratio( [0, 1]]) .. note:: - 1. preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype References: @@ -201,14 +201,14 @@ def signal_distortion_ratio( if not _FAST_BSS_EVAL_AVAILABLE: warnings.warn( "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. " - "To dispear this warning, you could install `fast-bss-eval` using `pip install fast-bss-eval` " + "To make this this warning disappear, you could install `fast-bss-eval` using `pip install fast-bss-eval` " "or set `use_cg_iter=None`. For this time, the solver provided by Pytorch is used.", UserWarning, ) elif not _TORCH_GREATER_EQUAL_1_8: warnings.warn( "The `use_cg_iter` parameter of `SDR` requires a Pytorch version >= 1.8. " - "To dispear this warning, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. " + "To make this this warning disappear, you could change to Pytorch v1.8+ or set `use_cg_iter=None`. " "For this time, the solver provided by Pytorch is used.", UserWarning, ) From 21ae77afe22c4fb5e26d09cc93827c29e26e1d7b Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 20 Apr 2022 13:58:58 +0000 Subject: [PATCH 13/30] rename --- torchmetrics/functional/audio/sdr.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 7835deecd3e..7f789085610 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -44,11 +44,11 @@ def solve(A: Tensor, b: Tensor) -> Tensor: toeplitz_conjugate_gradient = None -def _symmetric_toeplitz(v: Tensor) -> Tensor: - """Construct a symmetric Toeplitz matrix using v. +def _symmetric_toeplitz(vector: Tensor) -> Tensor: + """Construct a symmetric Toeplitz matrix using one vector. Args: - v: shape [..., L] + vector: shape [..., L] Example: >>> from torchmetrics.functional.audio.sdr import _symmetric_toeplitz @@ -64,9 +64,9 @@ def _symmetric_toeplitz(v: Tensor) -> Tensor: Returns: a symmetric Toeplitz matrix of shape [..., L, L] """ - vals = torch.cat([torch.flip(v, dims=(-1,)), v[..., 1:]], dim=-1) - L = v.shape[-1] - return torch.as_strided(vals, size=vals.shape[:-1] + (L, L), stride=vals.stride()[:-1] + (1, 1)).flip(dims=(-1,)) + vec_exp = torch.cat([torch.flip(vector, dims=(-1,)), vector[..., 1:]], dim=-1) + v_len = vector.shape[-1] + return torch.as_strided(vec_exp, size=vec_exp.shape[:-1] + (v_len, v_len), stride=vec_exp.stride()[:-1] + (1, 1)).flip(dims=(-1,)) def _compute_autocorr_crosscorr( From 111467253e011e586a91aadf591108748edcda55 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 20 Apr 2022 14:00:14 +0000 Subject: [PATCH 14/30] remove _FAST_BSS_EVAL_AVAILABLE --- torchmetrics/audio/sdr.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index ecf2da899d5..5464cdbbca3 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -17,7 +17,6 @@ from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE __doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]} @@ -97,11 +96,6 @@ def __init__( compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any], ) -> None: - if not _FAST_BSS_EVAL_AVAILABLE: - raise ModuleNotFoundError( - "SDR metric requires that `fast-bss-eval` is installed." - " Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`." - ) super().__init__(compute_on_step=compute_on_step, **kwargs) self.use_cg_iter = use_cg_iter From 7e4a595f8a7ac6633a3dded8be5082906a935d70 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:01:12 +0000 Subject: [PATCH 15/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 7f789085610..7dfa3f67581 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -66,7 +66,9 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor: """ vec_exp = torch.cat([torch.flip(vector, dims=(-1,)), vector[..., 1:]], dim=-1) v_len = vector.shape[-1] - return torch.as_strided(vec_exp, size=vec_exp.shape[:-1] + (v_len, v_len), stride=vec_exp.stride()[:-1] + (1, 1)).flip(dims=(-1,)) + return torch.as_strided( + vec_exp, size=vec_exp.shape[:-1] + (v_len, v_len), stride=vec_exp.stride()[:-1] + (1, 1) + ).flip(dims=(-1,)) def _compute_autocorr_crosscorr( From 94ad23b95ae653445a93ee47e9d8ed5618ca5f61 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 20 Apr 2022 16:07:32 +0200 Subject: [PATCH 16/30] fix too long line --- torchmetrics/functional/audio/sdr.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 7dfa3f67581..f1a802f6f29 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -203,8 +203,9 @@ def signal_distortion_ratio( if not _FAST_BSS_EVAL_AVAILABLE: warnings.warn( "The `use_cg_iter` parameter of `SDR` requires that `fast-bss-eval` is installed. " - "To make this this warning disappear, you could install `fast-bss-eval` using `pip install fast-bss-eval` " - "or set `use_cg_iter=None`. For this time, the solver provided by Pytorch is used.", + "To make this this warning disappear, you could install `fast-bss-eval` using " + "`pip install fast-bss-eval` or set `use_cg_iter=None`. For this time, the solver " + "provided by Pytorch is used.", UserWarning, ) elif not _TORCH_GREATER_EQUAL_1_8: From 54418c19c3d5475be8154e8ae8da63215bc6bba5 Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 20 Apr 2022 14:34:31 +0000 Subject: [PATCH 17/30] rename --- torchmetrics/functional/audio/sdr.py | 35 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index f1a802f6f29..fc772eab688 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -74,35 +74,34 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor: def _compute_autocorr_crosscorr( target: torch.Tensor, preds: torch.Tensor, - L: int, + corr_len: int, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds` using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the auto correlation of `target` as `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of - `preds` in the column space of the L shifts of `target`. + `preds` in the column space of the `corr_len` shifts of `target`. Args: target: the target (reference) signal of shape [..., time] preds: the preds (estimated) signal of shape [..., time] - L: the length of the auto correlation and cross correlation + corr_len: the length of the auto correlation and cross correlation Returns: - the auto correlation of `target` of shape [..., L] - the cross correlation of `target` and `preds` of shape [..., L] + the auto correlation of `target` of shape [..., corr_len] + the cross correlation of `target` and `preds` of shape [..., corr_len] """ # the valid length for the signal after convolution n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) # computes the auto correlation of `target` - T = torch.fft.rfft(target, n=n_fft, dim=-1) - # R_0 is the first row of the symmetric Toeplitz matric - R_0 = torch.fft.irfft(T.real**2 + T.imag**2, n=n_fft)[..., :L] + t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) + # r_0 is the first row of the symmetric Toeplitz matric + r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] # computes the cross-correlation of `target` and `preds` - P = torch.fft.rfft(preds, n=n_fft, dim=-1) - TP = T.conj() * P - b = torch.fft.irfft(TP, n=n_fft, dim=-1)[..., :L] - return R_0, b + p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) + b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] + return r_0, b def signal_distortion_ratio( @@ -189,15 +188,15 @@ def signal_distortion_ratio( # solve for the optimal filter # compute auto-correlation and cross-correlation - R_0, b = _compute_autocorr_crosscorr(target, preds, L=filter_length) + r_0, b = _compute_autocorr_crosscorr(target, preds, corr_len=filter_length) if load_diag is not None: - # the diagonal factor of the Toeplitz matrix is the first coefficient of R_0 - R_0[..., 0] += load_diag + # the diagonal factor of the Toeplitz matrix is the first coefficient of r_0 + r_0[..., 0] += load_diag if use_cg_iter is not None and _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: # use preconditioned conjugate gradient - sol = toeplitz_conjugate_gradient(R_0, b, n_iter=use_cg_iter) + sol = toeplitz_conjugate_gradient(r_0, b, n_iter=use_cg_iter) else: if use_cg_iter is not None: if not _FAST_BSS_EVAL_AVAILABLE: @@ -216,8 +215,8 @@ def signal_distortion_ratio( UserWarning, ) # regular matrix solver - R = _symmetric_toeplitz(R_0) # the auto-correlation of the L shifts of `target` - sol = solve(R, b) + r = _symmetric_toeplitz(r_0) # the auto-correlation of the L shifts of `target` + sol = solve(r, b) # compute the coherence coh = torch.einsum("...l,...l->...", b, sol) From af9ccac397a1fb8c8a4da09cf2546f38f97ef4fd Mon Sep 17 00:00:00 2001 From: quancs Date: Wed, 20 Apr 2022 14:40:25 +0000 Subject: [PATCH 18/30] fix --- torchmetrics/functional/audio/sdr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index fc772eab688..9c303b41865 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -183,8 +183,8 @@ def signal_distortion_ratio( target = target - target.mean(dim=-1, keepdim=True) # normalize along time-axis to make preds and target have unit norm - target = target / torch.clamp(norm(target, axis=-1, keepdims=True), min=1e-6) - preds = preds / torch.clamp(norm(preds, axis=-1, keepdims=True), min=1e-6) + target = target / torch.clamp(norm(target, dim=-1, keepdims=True), min=1e-6) + preds = preds / torch.clamp(norm(preds, dim=-1, keepdims=True), min=1e-6) # solve for the optimal filter # compute auto-correlation and cross-correlation From 436cd6497c7a557d1dbdf117b0646ecdba2e29c2 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 02:11:28 +0000 Subject: [PATCH 19/30] fft --- torchmetrics/functional/audio/sdr.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 9c303b41865..2e796dcbe3f 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -94,13 +94,28 @@ def _compute_autocorr_crosscorr( n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) # computes the auto correlation of `target` - t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) # r_0 is the first row of the symmetric Toeplitz matric - r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] + if _TORCH_GREATER_EQUAL_1_7: + t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) + r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] + else: + t_fft = torch.rfft(target, signal_ndim=1) + real = t_fft[..., 0]**2 + t_fft[..., 1]**2 + imag = torch.zeros(real.shape, dtype=real.dtype, device=real.device) + result = torch.stack([real, imag], len(real.shape)) + r_0 = torch.irfft(result, signal_ndim=1)[..., :corr_len] # computes the cross-correlation of `target` and `preds` - p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) - b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] + if _TORCH_GREATER_EQUAL_1_7: + p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) + b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] + else: + p_fft = torch.rfft(preds, signal_ndim=1) + real = t_fft[..., 0] * p_fft[..., 0] + t_fft[..., 1] * p_fft[..., 1] + imag = t_fft[..., 0] * p_fft[..., 1] - t_fft[..., 1] * p_fft[..., 0] + result = torch.stack([real, imag], len(real.shape)) + b = torch.irfft(result, signal_ndim=1)[..., :corr_len] + return r_0, b From 18654bb3c56c8df162d23c2d71bd4c64cea623b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Apr 2022 02:12:11 +0000 Subject: [PATCH 20/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 2e796dcbe3f..356ce809eb4 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -100,7 +100,7 @@ def _compute_autocorr_crosscorr( r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] else: t_fft = torch.rfft(target, signal_ndim=1) - real = t_fft[..., 0]**2 + t_fft[..., 1]**2 + real = t_fft[..., 0] ** 2 + t_fft[..., 1] ** 2 imag = torch.zeros(real.shape, dtype=real.dtype, device=real.device) result = torch.stack([real, imag], len(real.shape)) r_0 = torch.irfft(result, signal_ndim=1)[..., :corr_len] From 38ec9317daba229b860d355cca92344fe0c7a983 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 03:01:05 +0000 Subject: [PATCH 21/30] fix keepdim --- torchmetrics/functional/audio/sdr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 356ce809eb4..329314546ad 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -198,8 +198,8 @@ def signal_distortion_ratio( target = target - target.mean(dim=-1, keepdim=True) # normalize along time-axis to make preds and target have unit norm - target = target / torch.clamp(norm(target, dim=-1, keepdims=True), min=1e-6) - preds = preds / torch.clamp(norm(preds, dim=-1, keepdims=True), min=1e-6) + target = target / torch.clamp(norm(target, dim=-1, keepdim=True), min=1e-6) + preds = preds / torch.clamp(norm(preds, dim=-1, keepdim=True), min=1e-6) # solve for the optimal filter # compute auto-correlation and cross-correlation From 3a19194b99b1546422d8a30d74ab249bf861b071 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 08:19:03 +0000 Subject: [PATCH 22/30] fix solve --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 329314546ad..048e51fe60a 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -35,7 +35,7 @@ from torch import solve as _solve def solve(A: Tensor, b: Tensor) -> Tensor: - return _solve(b, A)[0] + return _solve(b[...,None], A)[0][...,0] if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: From b15743b448914ee97f18ab7201acb69f906f87e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Apr 2022 08:19:57 +0000 Subject: [PATCH 23/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 048e51fe60a..ce2b6bd8c86 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -35,7 +35,7 @@ from torch import solve as _solve def solve(A: Tensor, b: Tensor) -> Tensor: - return _solve(b[...,None], A)[0][...,0] + return _solve(b[..., None], A)[0][..., 0] if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: From a213c19490a01dec90bcc7f3de7c9fdf30b8623e Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 08:20:36 +0000 Subject: [PATCH 24/30] format --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 048e51fe60a..ce2b6bd8c86 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -35,7 +35,7 @@ from torch import solve as _solve def solve(A: Tensor, b: Tensor) -> Tensor: - return _solve(b[...,None], A)[0][...,0] + return _solve(b[..., None], A)[0][..., 0] if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: From c433a6087cac6be83c93daa21e4ffa1a59f5541d Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 10:07:19 +0000 Subject: [PATCH 25/30] use double precision --- tests/audio/test_sdr.py | 9 +-------- torchmetrics/audio/sdr.py | 2 +- torchmetrics/functional/audio/sdr.py | 25 +++++-------------------- 3 files changed, 7 insertions(+), 29 deletions(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 441d3751b03..7c8f1a15f12 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -154,14 +154,7 @@ def test_too_low_precision(): preds = torch.tensor(data["preds"]) target = torch.tensor(data["target"]) - if _TORCH_LOWER_1_12_DEV: - with pytest.warns( - UserWarning, - match="Detected `nan` or `inf` value in computed metric, retrying computation in double precision", - ): - sdr_tm = signal_distortion_ratio(preds, target) - else: # when pytorch >= 1.12, sdr doesn't have this problem - sdr_tm = signal_distortion_ratio(preds, target).double() + sdr_tm = signal_distortion_ratio(preds, target).double() # check equality with bss_eval_sources in every pytorch version sdr_bss, _, _, _ = bss_eval_sources(target.numpy(), preds.numpy(), False) diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index 5464cdbbca3..a3e613bd559 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -72,7 +72,7 @@ class SignalDistortionRatio(Metric): tensor(-11.6051) .. note:: - Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + Preds and target are converted to double precision in signal_distortion_ratio References: diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index ce2b6bd8c86..2a95a345f87 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -172,7 +172,7 @@ def signal_distortion_ratio( [0, 1]]) .. note:: - Preds and target need to have the same dtype, otherwise target will be converted to preds' dtype + Preds and target are converted to double precision in signal_distortion_ratio References: @@ -182,16 +182,10 @@ def signal_distortion_ratio( [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. """ _check_same_shape(preds, target) - - if not preds.dtype.is_floating_point: - preds = preds.float() # for torch.norm - - # half precision support - if preds.dtype == torch.float16: - preds = preds.to(torch.float32) - - if preds.dtype != target.dtype: # for torch.linalg.solve - target = target.to(preds.dtype) + + # use double precision + preds = preds.double() + target = target.double() if zero_mean: preds = preds - preds.mean(dim=-1, keepdim=True) @@ -239,15 +233,6 @@ def signal_distortion_ratio( # transform to decibels ratio = coh / (1 - coh) val = 10.0 * torch.log10(ratio) - - # recompute sdr in float64 if val is NaN or Inf - if (torch.isnan(val).any() or torch.isinf(val).any()) and preds.dtype != torch.float64: - warnings.warn( - "Detected `nan` or `inf` value in computed metric, retrying computation in double precision", - UserWarning, - ) - val = signal_distortion_ratio(preds.double(), target.double(), use_cg_iter, filter_length, zero_mean, load_diag) - return val From e847d16a515e23e79c98f72bfaf9374d964f5792 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Apr 2022 10:07:58 +0000 Subject: [PATCH 26/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 2a95a345f87..ba7aa2050c0 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -182,7 +182,7 @@ def signal_distortion_ratio( [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. """ _check_same_shape(preds, target) - + # use double precision preds = preds.double() target = target.double() From 8e7cf5f8315ff4498c857d2e74e4d1c5e83bce97 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 10:08:32 +0000 Subject: [PATCH 27/30] remove _TORCH_GREATER_EQUAL_1_7 --- torchmetrics/functional/audio/sdr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index 2a95a345f87..c2d84006453 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -20,10 +20,10 @@ from torch import Tensor from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 # import or def the norm function -if _TORCH_GREATER_EQUAL_1_7: +if _TORCH_GREATER_EQUAL_1_8: from torch.linalg import norm else: from torch import norm @@ -95,7 +95,7 @@ def _compute_autocorr_crosscorr( # computes the auto correlation of `target` # r_0 is the first row of the symmetric Toeplitz matric - if _TORCH_GREATER_EQUAL_1_7: + if _TORCH_GREATER_EQUAL_1_8: t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] else: @@ -106,7 +106,7 @@ def _compute_autocorr_crosscorr( r_0 = torch.irfft(result, signal_ndim=1)[..., :corr_len] # computes the cross-correlation of `target` and `preds` - if _TORCH_GREATER_EQUAL_1_7: + if _TORCH_GREATER_EQUAL_1_8: p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] else: @@ -182,7 +182,7 @@ def signal_distortion_ratio( [2] Scheibler, R. (2021). SDR -- Medium Rare with Fast Computations. """ _check_same_shape(preds, target) - + # use double precision preds = preds.double() target = target.double() From 7732764c6782dcc85fd0722df241ef198518ec72 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 13:22:17 +0000 Subject: [PATCH 28/30] fix rfft/irfft --- torchmetrics/audio/sdr.py | 4 ---- torchmetrics/functional/audio/sdr.py | 31 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/torchmetrics/audio/sdr.py b/torchmetrics/audio/sdr.py index a3e613bd559..ee5fb85c21b 100644 --- a/torchmetrics/audio/sdr.py +++ b/torchmetrics/audio/sdr.py @@ -71,10 +71,6 @@ class SignalDistortionRatio(Metric): >>> pit(preds, target) tensor(-11.6051) - .. note:: - Preds and target are converted to double precision in signal_distortion_ratio - - References: [1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index c2d84006453..fe19489235a 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -22,22 +22,18 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _FAST_BSS_EVAL_AVAILABLE, _TORCH_GREATER_EQUAL_1_8 -# import or def the norm function +# import or def the norm/solve function if _TORCH_GREATER_EQUAL_1_8: from torch.linalg import norm -else: - from torch import norm - -# import or redirect the solve function -if _TORCH_GREATER_EQUAL_1_8: solve = torch.linalg.solve else: + from torch import norm + from torch.nn.functional import pad from torch import solve as _solve def solve(A: Tensor, b: Tensor) -> Tensor: return _solve(b[..., None], A)[0][..., 0] - if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient else: @@ -99,22 +95,24 @@ def _compute_autocorr_crosscorr( t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] else: - t_fft = torch.rfft(target, signal_ndim=1) + t_pad = pad(target, (0, n_fft - target.shape[-1]), "constant", 0) + t_fft = torch.rfft(t_pad, signal_ndim=1) real = t_fft[..., 0] ** 2 + t_fft[..., 1] ** 2 imag = torch.zeros(real.shape, dtype=real.dtype, device=real.device) result = torch.stack([real, imag], len(real.shape)) - r_0 = torch.irfft(result, signal_ndim=1)[..., :corr_len] + r_0 = torch.irfft(result, signal_ndim=1, signal_sizes=[n_fft])[..., :corr_len] # computes the cross-correlation of `target` and `preds` if _TORCH_GREATER_EQUAL_1_8: p_fft = torch.fft.rfft(preds, n=n_fft, dim=-1) b = torch.fft.irfft(t_fft.conj() * p_fft, n=n_fft, dim=-1)[..., :corr_len] else: - p_fft = torch.rfft(preds, signal_ndim=1) + p_pad = pad(preds, (0, n_fft - preds.shape[-1]), "constant", 0) + p_fft = torch.rfft(p_pad, signal_ndim=1) real = t_fft[..., 0] * p_fft[..., 0] + t_fft[..., 1] * p_fft[..., 1] imag = t_fft[..., 0] * p_fft[..., 1] - t_fft[..., 1] * p_fft[..., 0] result = torch.stack([real, imag], len(real.shape)) - b = torch.irfft(result, signal_ndim=1)[..., :corr_len] + b = torch.irfft(result, signal_ndim=1, signal_sizes=[n_fft])[..., :corr_len] return r_0, b @@ -171,10 +169,6 @@ def signal_distortion_ratio( [1, 0], [0, 1]]) - .. note:: - Preds and target are converted to double precision in signal_distortion_ratio - - References: [1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469. @@ -184,6 +178,7 @@ def signal_distortion_ratio( _check_same_shape(preds, target) # use double precision + preds_dtype = preds.dtype preds = preds.double() target = target.double() @@ -233,7 +228,11 @@ def signal_distortion_ratio( # transform to decibels ratio = coh / (1 - coh) val = 10.0 * torch.log10(ratio) - return val + + if preds_dtype == torch.float64: + return val + else: + return val.float() def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor: From 9d91d3b4319def6d85897802a17db7720378239c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Apr 2022 13:31:35 +0000 Subject: [PATCH 29/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/audio/sdr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/audio/sdr.py b/torchmetrics/functional/audio/sdr.py index fe19489235a..8a3b9490a9e 100644 --- a/torchmetrics/functional/audio/sdr.py +++ b/torchmetrics/functional/audio/sdr.py @@ -25,15 +25,17 @@ # import or def the norm/solve function if _TORCH_GREATER_EQUAL_1_8: from torch.linalg import norm + solve = torch.linalg.solve else: from torch import norm - from torch.nn.functional import pad from torch import solve as _solve + from torch.nn.functional import pad def solve(A: Tensor, b: Tensor) -> Tensor: return _solve(b[..., None], A)[0][..., 0] + if _FAST_BSS_EVAL_AVAILABLE and _TORCH_GREATER_EQUAL_1_8: from fast_bss_eval.torch.cgd import toeplitz_conjugate_gradient else: From 9343d5d4f17e14bf9090f88e6ba932bcd88e1521 Mon Sep 17 00:00:00 2001 From: quancs Date: Thu, 21 Apr 2022 13:37:15 +0000 Subject: [PATCH 30/30] remove _TORCH_LOWER_1_12_DEV --- tests/audio/test_sdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/audio/test_sdr.py b/tests/audio/test_sdr.py index 7c8f1a15f12..15e31d80664 100644 --- a/tests/audio/test_sdr.py +++ b/tests/audio/test_sdr.py @@ -27,7 +27,7 @@ from tests.helpers.testers import MetricTester from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_LOWER_1_12_DEV +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 seed_all(42)