Skip to content

Commit

Permalink
refactor docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 11, 2022
1 parent bad0a08 commit 4c1a539
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 40 deletions.
10 changes: 10 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,13 @@ @article{higuchi2017online
year={2017},
publisher={IEEE}
}
@article{capon1969high,
title={High-resolution frequency-wavenumber spectrum analysis},
author={Capon, Jack},
journal={Proceedings of the IEEE},
volume={57},
number={8},
pages={1408--1418},
year={1969},
publisher={IEEE}
}
41 changes: 20 additions & 21 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,17 +1687,16 @@ def compute_power_spectral_density_matrix(
"""Compute cross-channel power spectral density (PSD) matrix.
Args:
specgram (Tensor): multi-channel complex-valued STFT matrix.
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
mask (Tensor or None, optional): Time-Frequency mask for normalization.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
of dimension `(..., channel, freq, time)` if multi_mask is ``True``.
mask (Tensor or None, optional): Real-valued Time-Frequency mask
for normalization. Tensor of dimension `(..., freq, time)`
(Default: ``None``)
normalize (bool, optional): whether to normalize the mask along the time dimension.
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)
Returns:
Tensor: PSD matrix of the input STFT matrix.
Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)`
"""
specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
Expand All @@ -1724,7 +1723,7 @@ def compute_mvdr_weights_souden(
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
r"""Compute the Minimum Variance Distortionless Response (MVDR) beamforming weights
r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
by the method proposed by *Souden et al.* [:footcite:`souden2009optimal`].
.. math::
Expand All @@ -1737,9 +1736,9 @@ def compute_mvdr_weights_souden(
:math:`\bf{u}` is a one-hot vector that represents the reference channel.
Args:
psd_s (Tensor): The covariance matrix of target speech.
psd_s (Tensor): The complex-valued covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The covariance matrix of noise.
psd_n (Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represents the reference channel index.
Expand All @@ -1751,7 +1750,7 @@ def compute_mvdr_weights_souden(
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
Returns:
Tensor: the mvdr beamforming weight matrix of dimension (..., freq, channel).
Tensor: the complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
Expand All @@ -1776,7 +1775,7 @@ def compute_mvdr_weights_rtf(
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
r"""Compute the Minimum Variance Distortionless Response (MVDR) beamforming weights
r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
based on the relative transfer function (RTF) and PSD matrix of noise.
.. math::
Expand All @@ -1788,9 +1787,9 @@ def compute_mvdr_weights_rtf(
:math:`.^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
Args:
rtf (Tensor): RTF vector of target speech.
rtf (Tensor): The complex-valued RTF vector of target speech.
Tensor of dimension `(..., freq, channel)`.
psd_n (torch.Tensor): The covariance matrix of noise.
psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represents the reference channel index.
Expand All @@ -1803,7 +1802,7 @@ def compute_mvdr_weights_rtf(
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
Returns:
Tensor: The MVDR beamforming weight matrix of dimension (..., freq, channel).
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
Expand All @@ -1827,11 +1826,11 @@ def compute_rtf_evd(psd_s: Tensor) -> Tensor:
r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.
Args:
psd_s (Tensor): The covariance matrix of target speech.
psd_s (Tensor): The complex-valued covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
Returns:
Tensor: The estimated RTF of target speech.
Tensor: The estimated complex-valued RTF of target speech.
Tensor of dimension `(..., freq, channel)`
"""
w, v = torch.linalg.eigh(psd_s) # (..., freq, channel, channel)
Expand All @@ -1845,17 +1844,17 @@ def compute_rtf_power(
r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.
Args:
psd_s (Tensor): The covariance matrix of target speech.
psd_s (Tensor): The complex-valued covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The covariance matrix of noise.
psd_n (Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
iterations (int): number of iterations in power method. (Default: ``3``)
Returns:
Tensor: the estimated RTF of target speech
Tensor: the estimated complex-valued RTF of target speech
Tensor of dimension `(..., freq, channel)`
"""
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
Expand Down Expand Up @@ -1883,13 +1882,13 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
:math:`\textbf{Y}` is the multi-channel spectrogram for the :math:`f`-th frequency bin.
Args:
beamform_weights (Tensor): The beamforming weight matrix.
beamform_weights (Tensor): The complex-valued beamforming weight matrix.
Tensor of dimension `(..., freq, channel)`
specgram (Tensor): The multi-channel noisy spectrogram.
specgram (Tensor): The multi-channel complex-valued noisy spectrum.
Tensor of dimension `(..., channel, freq, time)`
Returns:
Tensor: The single-channel enhanced spectrogram.
Tensor: The single-channel complex-valued enhanced spectrum.
Tensor of dimension `(..., freq, time)`
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
Expand Down
39 changes: 20 additions & 19 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,14 +1598,14 @@ def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float
def forward(self, specgram: Tensor, mask: Optional[Tensor] = None):
"""
Args:
specgram (Tensor): multi-channel complex-valued STFT matrix.
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
mask (Tensor or None, optional): Time-Frequency mask for normalization.
mask (Tensor or None, optional): Real-valued Time-Frequency mask for normalization.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
of dimension `(..., channel, freq, time)` if multi_mask is ``True``
Returns:
Tensor: PSD matrix of the input STFT matrix.
Tensor: Complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)`
"""
if mask is not None and self.multi_mask:
Expand All @@ -1615,7 +1615,8 @@ def forward(self, specgram: Tensor, mask: Optional[Tensor] = None):


class MVDR(torch.nn.Module):
"""Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks.
"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
that performs MVDR beamforming with Time-Frequency masks.
Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py
Expand Down Expand Up @@ -1645,12 +1646,12 @@ class MVDR(torch.nn.Module):
[:footcite:`higuchi2016robust`] or the *power method* [:footcite:`mises1929praktische`] to get the
steering vector from the PSD matrix of speech.
After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by
After estimating the beamforming weight, the enhanced spectrum is obtained by
.. math::
\\hat{\\bf{S}} = {\\bf{w}^\\mathsf{H}}{\\bf{Y}}, {\\bf{w}} \\in \\mathbb{C}^{M \\times F}
where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the STFT of the multi-channel noisy speech and\
where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the spectra of the multi-channel noisy speech and\
the single-channel enhanced speech, respectively.
For online streaming audio, we provide a *recursive method* [:footcite:`higuchi2017online`] to update the
Expand Down Expand Up @@ -1726,13 +1727,13 @@ def _get_updated_psds(
r"""Recursively update the MVDR beamforming vector.
Args:
psd_s (Tensor): psd matrix of target speech
psd_n (Tensor): psd matrix of noise
mask_s (Tensor): T-F mask of target speech
mask_n (Tensor): T-F mask of noise
psd_s (Tensor): Complex-valued PSD matrix of target speech
psd_n (Tensor): Complex-valued PSD matrix of noise
mask_s (Tensor): Real-valued T-F mask of target speech
mask_n (Tensor): Real-valued T-F mask of noise
Returns:
Tensor: the updated PSD matrix of speech
Tensor: the updated PSD matrix of target speech
Tensor: the updated PSD matrix of noise
"""
if self.multi_mask:
Expand All @@ -1755,8 +1756,8 @@ def _get_updated_psd_speech(self, psd_s: Tensor, mask_s: Tensor) -> Tensor:
r"""Update PSD of speech recursively.
Args:
psd_s (Tensor): PSD matrix of target speech
mask_s (Tensor): T-F mask of target speech
psd_s (Tensor): Complex-valued PSD matrix of target speech
mask_s (Tensor): Real-valued T-F mask of target speech
Returns:
Tensor: the updated PSD matrix of speech
Expand All @@ -1770,8 +1771,8 @@ def _get_updated_psd_noise(self, psd_n: Tensor, mask_n: Tensor) -> Tensor:
r"""Update PSD of noise recursively.
Args:
psd_n (Tensor): PSD matrix of target noise
mask_n (Tensor): T-F mask of target noise
psd_n (Tensor): Complex-valued PSD matrix of target noise
mask_n (Tensor): Real-valued T-F mask of target noise
Returns:
Tensor: the updated PSD matrix of noise
Expand All @@ -1785,18 +1786,18 @@ def forward(self, specgram: Tensor, mask_s: Tensor, mask_n: Optional[Tensor] = N
"""Perform MVDR beamforming.
Args:
specgram (Tensor): the multi-channel STF of the noisy speech.
specgram (Tensor): The multi-channel complex-valued spectrum of the noisy speech.
Tensor of dimension `(..., channel, freq, time)`
mask_s (Tensor): Time-Frequency mask of target speech.
mask_s (Tensor): Real-valued Time-Frequency mask of target speech.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or of dimension `(..., channel, freq, time)` if multi_mask is ``True``
mask_n (Tensor or None, optional): Time-Frequency mask of noise.
mask_n (Tensor or None, optional): Real-valued Time-Frequency mask of noise.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or of dimension `(..., channel, freq, time)` if multi_mask is ``True``
(Default: None)
Returns:
Tensor: The single-channel STFT of the enhanced speech.
Tensor: The single-channel complex-valued spectrum of the enhanced speech.
Tensor of dimension `(..., freq, time)`
"""
dtype = specgram.dtype
Expand Down

0 comments on commit 4c1a539

Please sign in to comment.