diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 2bb19e2c86..4fe5acce26 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas
 
   .. automethod:: forward
 
+:hidden:`RTFMVDR`
+-----------------
+
+.. autoclass:: RTFMVDR
+
+  .. automethod:: forward
+
 :hidden:`SoudenMVDR`
 --------------------
 
diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py
index 7dc587f5df..ce3899d1f0 100644
--- a/test/torchaudio_unittest/transforms/autograd_test_impl.py
+++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -303,6 +303,16 @@ def test_mvdr(self, solution):
         mask_n = torch.rand(spectrogram.shape[-2:])
         self.assert_grad(transform, [spectrogram, mask_s, mask_n])
 
+    def test_rtf_mvdr(self):
+        transform = T.RTFMVDR()
+        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
+        specgram = get_spectrogram(waveform, n_fft=400)
+        channel, freq, _ = specgram.shape
+        rtf = torch.rand(freq, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        self.assert_grad(transform, [specgram, rtf, psd_n, reference_channel])
+
     def test_souden_mvdr(self):
         transform = T.SoudenMVDR()
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py
index 1657e340e7..7a3d7ad4ae 100644
--- a/test/torchaudio_unittest/transforms/batch_consistency_test.py
+++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -220,6 +220,25 @@ def test_MVDR(self, multi_mask):
 
         self.assertEqual(computed, expected)
 
+    def test_rtf_mvdr(self):
+        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
+        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
+        batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
+        specgram = specgram.reshape(batch_size, channel, freq, time)
+        rtf = torch.rand(batch_size, freq, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        transform = T.RTFMVDR()
+
+        # Single then transform then batch
+        expected = [transform(specgram[i], rtf[i], psd_n[i], reference_channel) for i in range(batch_size)]
+        expected = torch.stack(expected)
+
+        # Batch then transform
+        computed = transform(specgram, rtf, psd_n, reference_channel)
+
+        self.assertEqual(computed, expected)
+
     def test_souden_mvdr(self):
         waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
         specgram = common_utils.get_spectrogram(waveform, n_fft=400)
diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
index 704c1c6bd5..c21022fa6f 100644
--- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
@@ -175,6 +175,15 @@ def test_MVDR(self, solution, online):
         mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
         self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
 
+    def test_rtf_mvdr(self):
+        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
+        specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
+        channel, freq, _ = specgram.shape
+        rtf = torch.rand(freq, channel, dtype=self.complex_dtype, device=self.device)
+        psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
+        reference_channel = 0
+        self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)
+
     def test_souden_mvdr(self):
         tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
         specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 2872cbdd06..79aba1107f 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1892,35 +1892,39 @@ def mvdr_weights_rtf(
 
     .. properties:: Autograd TorchScript
 
+    Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
+    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
+    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
+    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
+
     .. math::
         \textbf{w}_{\text{MVDR}}(f) =
         \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
-    where :math:`\bm{v}` is the RTF or the steering vector.
-    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+    where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
 
     Args:
-        rtf (Tensor): The complex-valued RTF vector of target speech.
-            Tensor of dimension `(..., freq, channel)`.
-        psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
-            Tensor of dimension `(..., freq, channel, channel)`
-        reference_channel (int or Tensor, optional): Indicate the reference channel.
-            If the dtype is ``int``, it represent the reference channel index.
-            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
+        rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+            Tensor with dimensions `(..., freq, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_channel (int or torch.Tensor): Specifies the reference channel.
+            If the dtype is ``int``, it represents the reference channel index.
+            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
             is one-hot.
-            If a non-None value is given, the MVDR weights will be normalized by ``rtf[..., reference_channel].conj()``
-            (Default: ``None``)
-        diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
             (Default: ``True``)
-        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
-            (Default: ``1e-7``)
-        eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
+        diag_eps (float, optional): The coefficient multiplied by the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+            (Default: ``1e-8``)
 
     Returns:
-        Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
+        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
     """
     if diagonal_loading:
-        psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
+        psd_n = _tik_reg(psd_n, reg=diag_eps)
     # numerator = psd_n.inv() @ stv
     numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)  # (..., freq, channel)
     # denominator = stv^H @ psd_n.inv() @ stv
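As a minimal usage sketch (not part of the patch itself), the weights returned by ``mvdr_weights_rtf`` feed directly into the existing ``apply_beamforming`` functional; the random ``rtf`` and ``psd_n`` below are placeholders for estimates that would normally come from an RTF estimator and a noise PSD estimate::

    import torch
    import torchaudio.functional as F

    channel, freq, frames = 4, 201, 100

    # Placeholder inputs; in practice rtf comes from an RTF/steering-vector estimator
    # and psd_n from a noise PSD estimate (e.g. torchaudio.transforms.PSD with a noise mask).
    specgram = torch.rand(channel, freq, frames, dtype=torch.cfloat)  # multi-channel STFT
    rtf = torch.rand(freq, channel, dtype=torch.cfloat)               # RTF of target speech
    psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)    # noise PSD matrix

    # MVDR weights per frequency bin, normalized for reference channel 0.
    w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0)      # (freq, channel)

    # Apply the weights to obtain the single-channel enhanced spectrum.
    enhanced = F.apply_beamforming(w_mvdr, specgram)                  # (freq, frames)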
diff --git a/torchaudio/transforms/__init__.py b/torchaudio/transforms/__init__.py
index e8de7e800e..a84ed64ce0 100644
--- a/torchaudio/transforms/__init__.py
+++ b/torchaudio/transforms/__init__.py
@@ -24,6 +24,7 @@
     RNNTLoss,
     PSD,
     MVDR,
+    RTFMVDR,
     SoudenMVDR,
 )
 
@@ -46,6 +47,7 @@
     "PSD",
     "PitchShift",
     "RNNTLoss",
+    "RTFMVDR",
     "Resample",
     "SlidingWindowCmn",
     "SoudenMVDR",
diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py
index ba48b85c81..4cd96e6060 100644
--- a/torchaudio/transforms/_transforms.py
+++ b/torchaudio/transforms/_transforms.py
@@ -2091,6 +2091,70 @@ def forward(
         return specgram_enhanced
 
 
+class RTFMVDR(torch.nn.Module):
+    r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
+    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
+    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+    .. math::
+        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin, and
+    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+    The beamforming weight is computed by:
+
+    .. math::
+        \textbf{w}_{\text{MVDR}}(f) =
+        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
+        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
+    """
+
+    def forward(
+        self,
+        specgram: Tensor,
+        rtf: Tensor,
+        psd_n: Tensor,
+        reference_channel: Union[int, Tensor],
+        diagonal_loading: bool = True,
+        diag_eps: float = 1e-7,
+        eps: float = 1e-8,
+    ) -> Tensor:
+        """
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`.
+            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+                Tensor with dimensions `(..., freq, channel)`.
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            reference_channel (int or torch.Tensor): Specifies the reference channel.
+                If the dtype is ``int``, it represents the reference channel index.
+                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+                is one-hot.
+            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+                (Default: ``True``)
+            diag_eps (float, optional): The coefficient multiplied by the identity matrix for diagonal loading.
+                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+        return spectrum_enhanced
+
+
 class SoudenMVDR(torch.nn.Module):
     r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
     based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
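Outside the patch itself, a minimal end-to-end sketch of how the new ``RTFMVDR`` transform could be driven is shown below; the mask-based PSD estimation and the eigenvector-based RTF estimate are illustrative assumptions, since the patch does not prescribe how ``rtf`` and ``psd_n`` are obtained::

    import torch
    import torchaudio.transforms as T

    # Multi-channel mixture: (channel, time) waveform -> (channel, freq, time) complex STFT.
    waveform = torch.rand(4, 8000)
    stft = T.Spectrogram(n_fft=400, power=None)
    specgram = stft(waveform)

    # Assumed precomputed time-frequency masks for speech and noise (e.g. from a neural network).
    mask_s = torch.rand(specgram.shape[-2:])
    mask_n = torch.rand(specgram.shape[-2:])

    # Estimate PSD matrices with the existing PSD transform.
    psd = T.PSD()
    psd_s = psd(specgram, mask_s)  # (freq, channel, channel)
    psd_n = psd(specgram, mask_n)  # (freq, channel, channel)

    # Illustrative RTF estimate: principal eigenvector of the speech PSD matrix.
    _, v = torch.linalg.eigh(psd_s)
    rtf = v[..., -1]  # (freq, channel)

    # Beamform with the new transform and go back to the time domain.
    rtf_mvdr = T.RTFMVDR()
    enhanced = rtf_mvdr(specgram, rtf, psd_n, reference_channel=0)  # (freq, time)
    istft = T.InverseSpectrogram(n_fft=400)
    enhanced_waveform = istft(enhanced, length=waveform.shape[-1])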