diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index f7109c6531..2f08a4fd48 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1831,34 +1831,37 @@ def mvdr_weights_souden( .. properties:: Autograd TorchScript + Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, + the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the + reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix + :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as: + .. math:: \textbf{w}_{\text{MVDR}}(f) = \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)} {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u} - where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}` - are the power spectral density (PSD) matrices of speech and noise, respectively. - :math:`\bf{u}` is a one-hot vector that represents the reference channel. Args: - psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech. - Tensor of dimension `(..., freq, channel, channel)` - psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise. - Tensor of dimension `(..., freq, channel, channel)` - reference_channel (int or Tensor): Indicate the reference channel. - If the dtype is ``int``, it represent the reference channel index. - If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension + psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech. + Tensor with dimensions `(..., freq, channel, channel)`. + psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise. + Tensor with dimensions `(..., freq, channel, channel)`. + reference_channel (int or torch.Tensor): Specifies the reference channel. + If the dtype is ``int``, it represents the reference channel index. + If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension is one-hot. - diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n + diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``. (Default: ``True``) - diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading - (Default: ``1e-7``) - eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``) + diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading. + It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``) + eps (float, optional): Value to add to the denominator in the beamforming weight formula. + (Default: ``1e-8``) Returns: - Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel). + torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`. """ if diagonal_loading: - psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps) + psd_n = _tik_reg(psd_n, reg=diag_eps) numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s # ws: (..., C, C) / (...,) -> (..., C, C) ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps) diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index 7b91bac2d5..dbb66f93cb 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -2099,14 +2099,22 @@ class SoudenMVDR(torch.nn.Module): .. properties:: Autograd TorchScript + Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix + of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and + a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel + complex-valued spectrum of the enhaned speech :math:`\hat{\textbf{S}}`. The formula is defined as: + + .. math:: + \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f) + + where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin. + + The beamforming weight is computed by: + .. math:: \textbf{w}_{\text{MVDR}}(f) = \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)} {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u} - - where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}` - are the power spectral density (PSD) matrices of speech and noise, respectively. - :math:`\bf{u}` is a one-hot vector that represents the reference channel. """ def forward( @@ -2118,28 +2126,28 @@ def forward( diagonal_loading: bool = True, diag_eps: float = 1e-7, eps: float = 1e-8, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - specgram (Tensor): Multi-channel complex-valued spectrum. - Tensor of dimension `(..., channel, freq, time)` - psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech. - Tensor of dimension `(..., freq, channel, channel)` - psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise. - Tensor of dimension `(..., freq, channel, channel)` - reference_channel (int or Tensor): Indicate the reference channel. - If the dtype is ``int``, it represent the reference channel index. - If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension + specgram (torch.Tensor): Multi-channel complex-valued spectrum. + Tensor with dimensions `(..., channel, freq, time)`. + psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech. + Tensor with dimensions `(..., freq, channel, channel)`. + psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise. + Tensor with dimensions `(..., freq, channel, channel)`. + reference_channel (int or torch.Tensor): Specifies the reference channel. + If the dtype is ``int``, it represents the reference channel index. + If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension is one-hot. - diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n + diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``. (Default: ``True``) - diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading - (Default: ``1e-7``) - eps (float, optional): a value added to the denominator in the beamforming weight computation. + diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading. + It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``) + eps (float, optional): Value to add to the denominator in the beamforming weight formula. (Default: ``1e-8``) Returns: - Tensor: Single-channel complex-valued enhanced spectrum of dimension (..., freq, time). + torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`. """ w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps) spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)