diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 056cf3a0d8..2bb19e2c86 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas
 
   .. automethod:: forward
 
+:hidden:`SoudenMVDR`
+--------------------
+
+.. autoclass:: SoudenMVDR
+
+  .. automethod:: forward
+
 References
 ~~~~~~~~~~
 
diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py
index cafcd3f581..7dc587f5df 100644
--- a/test/torchaudio_unittest/transforms/autograd_test_impl.py
+++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -303,6 +303,16 @@ def test_mvdr(self, solution):
         mask_n = torch.rand(spectrogram.shape[-2:])
         self.assert_grad(transform, [spectrogram, mask_s, mask_n])
 
+    def test_souden_mvdr(self):
+        transform = T.SoudenMVDR()
+        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
+        specgram = get_spectrogram(waveform, n_fft=400)
+        channel, freq, _ = specgram.shape
+        psd_s = torch.rand(freq, channel, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        self.assert_grad(transform, [specgram, psd_s, psd_n, reference_channel])
+
 
 class AutogradTestFloat32(TestBaseMixin):
     def assert_grad(
diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py
index 67c32419f7..1657e340e7 100644
--- a/test/torchaudio_unittest/transforms/batch_consistency_test.py
+++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -219,3 +219,22 @@ def test_MVDR(self, multi_mask):
         computed = transform(specgram, mask_s, mask_n)
 
         self.assertEqual(computed, expected)
+
+    def test_souden_mvdr(self):
+        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
+        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
+        batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
+        specgram = specgram.reshape(batch_size, channel, freq, time)
+        psd_s = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        transform = T.SoudenMVDR()
+
+        # Single then transform then batch
+        expected = [transform(specgram[i], psd_s[i], psd_n[i], reference_channel) for i in range(batch_size)]
+        expected = torch.stack(expected)
+
+        # Batch then transform
+        computed = transform(specgram, psd_s, psd_n, reference_channel)
+
+        self.assertEqual(computed, expected)
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index 04be28db7c..704c1c6bd5 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -175,6 +175,15 @@ def test_MVDR(self, solution, online):
         mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
         self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
 
+    def test_souden_mvdr(self):
+        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
+        specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
+        channel, freq, _ = specgram.shape
+        psd_s = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
+        psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
+        reference_channel = 0
+        self._assert_consistency_complex(T.SoudenMVDR(), specgram, psd_s, psd_n, reference_channel)
+
 
 class TransformsFloat32Only(TestBaseMixin):
     def test_rnnt_loss(self):
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index f7109c6531..2f08a4fd48 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1831,34 +1831,37 @@ def mvdr_weights_souden(
 
     .. properties:: Autograd TorchScript
 
+    Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
+    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
+    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
+    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
+
     .. math::
         \textbf{w}_{\text{MVDR}}(f) =
         \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
 
-    where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
-    are the power spectral density (PSD) matrices of speech and noise, respectively.
-    :math:`\bf{u}` is a one-hot vector that represents the reference channel.
 
     Args:
-        psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
-            Tensor of dimension `(..., freq, channel, channel)`
-        psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
-            Tensor of dimension `(..., freq, channel, channel)`
-        reference_channel (int or Tensor): Indicate the reference channel.
-            If the dtype is ``int``, it represent the reference channel index.
-            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
+        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_channel (int or torch.Tensor): Specifies the reference channel.
+            If the dtype is ``int``, it represents the reference channel index.
+            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
             is one-hot.
-        diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
             (Default: ``True``)
-        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
-            (Default: ``1e-7``)
-        eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+            (Default: ``1e-8``)
 
     Returns:
-        Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
+        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
""" if diagonal_loading: - psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps) + psd_n = _tik_reg(psd_n, reg=diag_eps) numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s # ws: (..., C, C) / (...,) -> (..., C, C) ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps) diff --git a/torchaudio/transforms/__init__.py b/torchaudio/transforms/__init__.py index f94ff4eeab..e8de7e800e 100644 --- a/torchaudio/transforms/__init__.py +++ b/torchaudio/transforms/__init__.py @@ -24,6 +24,7 @@ RNNTLoss, PSD, MVDR, + SoudenMVDR, ) @@ -47,6 +48,7 @@ "RNNTLoss", "Resample", "SlidingWindowCmn", + "SoudenMVDR", "SpectralCentroid", "Spectrogram", "TimeMasking", diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index 7c091b6e1d..ba48b85c81 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -2,7 +2,7 @@ import math import warnings -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from torch import Tensor @@ -2089,3 +2089,66 @@ def forward( specgram_enhanced.to(dtype) return specgram_enhanced + + +class SoudenMVDR(torch.nn.Module): + r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module + based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`]. + + .. devices:: CPU CUDA + + .. properties:: Autograd TorchScript + + Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix + of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and + a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel + complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as: + + .. math:: + \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f) + + where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin. + + The beamforming weight is computed by: + + .. math:: + \textbf{w}_{\text{MVDR}}(f) = + \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)} + {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u} + """ + + def forward( + self, + specgram: Tensor, + psd_s: Tensor, + psd_n: Tensor, + reference_channel: Union[int, Tensor], + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Args: + specgram (torch.Tensor): Multi-channel complex-valued spectrum. + Tensor with dimensions `(..., channel, freq, time)`. + psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech. + Tensor with dimensions `(..., freq, channel, channel)`. + psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise. + Tensor with dimensions `(..., freq, channel, channel)`. + reference_channel (int or torch.Tensor): Specifies the reference channel. + If the dtype is ``int``, it represents the reference channel index. + If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension + is one-hot. + diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``. + (Default: ``True``) + diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading. + It is only effective when ``diagonal_loading`` is set to ``True``. 
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+        return spectrum_enhanced
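
Usage sketch (not part of the patch): the snippet below shows how the new ``SoudenMVDR`` transform is expected to fit into a typical masking-based beamforming pipeline, together with the existing ``torchaudio.transforms.PSD``, ``Spectrogram``, and ``InverseSpectrogram`` transforms. The random waveform and masks are placeholders for a real multi-channel recording and the output of a separately trained mask estimator, so treat the exact shapes and numbers as illustrative assumptions rather than part of the change itself.

    import torch
    import torchaudio.transforms as T

    # Multi-channel signal: (channel, time). A real pipeline would load a recording here.
    waveform = torch.randn(4, 16000)

    # Complex-valued multi-channel spectrum: (channel, freq, time)
    stft = T.Spectrogram(n_fft=400, power=None)
    specgram = stft(waveform)
    _, freq, time = specgram.shape

    # Placeholder speech/noise masks of shape (freq, time); normally predicted by a neural network.
    mask_s = torch.rand(freq, time)
    mask_n = torch.rand(freq, time)

    # PSD matrices of target speech and noise: (freq, channel, channel)
    psd = T.PSD()
    psd_s = psd(specgram, mask_s)
    psd_n = psd(specgram, mask_n)

    # Souden MVDR beamforming with channel 0 as the reference; output is (freq, time)
    beamformer = T.SoudenMVDR()
    enhanced = beamformer(specgram, psd_s, psd_n, reference_channel=0)

    # Back to a single-channel time-domain waveform
    istft = T.InverseSpectrogram(n_fft=400)
    enhanced_waveform = istft(enhanced, length=waveform.shape[-1])

As the new ``forward`` above shows, ``T.SoudenMVDR`` is a thin wrapper over ``F.mvdr_weights_souden`` followed by ``F.apply_beamforming``, so the same pipeline can also be written with the functional API when the beamforming weights need to be inspected or reused.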