Skip to content

Commit

Permalink
fix docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed May 6, 2022
1 parent ad231cd commit a29ba8e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 35 deletions.
35 changes: 19 additions & 16 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,34 +1831,37 @@ def mvdr_weights_souden(
.. properties:: Autograd TorchScript
Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
:math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
are the power spectral density (PSD) matrices of speech and noise, respectively.
:math:`\bf{u}` is a one-hot vector that represents the reference channel.
Args:
psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
psd_n = _tik_reg(psd_n, reg=diag_eps)
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
Expand Down
46 changes: 27 additions & 19 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2099,14 +2099,22 @@ class SoudenMVDR(torch.nn.Module):
.. properties:: Autograd TorchScript
Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
complex-valued spectrum of the enhaned speech :math:`\hat{\textbf{S}}`. The formula is defined as:
.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.
The beamforming weight is computed by:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
are the power spectral density (PSD) matrices of speech and noise, respectively.
:math:`\bf{u}` is a one-hot vector that represents the reference channel.
"""

def forward(
Expand All @@ -2118,28 +2126,28 @@ def forward(
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
) -> torch.Tensor:
"""
Args:
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in the beamforming weight computation.
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
Tensor: Single-channel complex-valued enhanced spectrum of dimension (..., freq, time).
torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
"""
w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
Expand Down

0 comments on commit a29ba8e

Please sign in to comment.