Skip to content

Commit

Permalink
Add RTFMVDR module (#2368)
Browse files Browse the repository at this point in the history
Summary:
Add a new design of MVDR module.
The RTFMVDR module supports the method based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
The input arguments are:
- multi-channel spectrum.
- RTF vector of the target speech
- PSD matrix of noise.
- reference channel in the microphone array.
- diagonal_loading option to enable or disable diagonal loading in matrix inverse computation.
- diag_eps for computing the inverse of the matrix.
- eps for computing the beamforming weight.
The output of the module is the single-channel complex-valued spectrum for the enhanced speech.

Pull Request resolved: #2368

Reviewed By: carolineechen

Differential Revision: D36214940

Pulled By: nateanl

fbshipit-source-id: 5f29f778663c96591e1b520b15f7876d07116937
  • Loading branch information
nateanl authored and facebook-github-bot committed May 10, 2022
1 parent da1e83c commit 4b021ae
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 17 deletions.
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`RTFMVDR`
-----------------

.. autoclass:: RTFMVDR

.. automethod:: forward

:hidden:`SoudenMVDR`
--------------------

Expand Down
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,16 @@ def test_mvdr(self, solution):
mask_n = torch.rand(spectrogram.shape[-2:])
self.assert_grad(transform, [spectrogram, mask_s, mask_n])

def test_rtf_mvdr(self):
transform = T.RTFMVDR()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
specgram = get_spectrogram(waveform, n_fft=400)
channel, freq, _ = specgram.shape
rtf = torch.rand(freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
self.assert_grad(transform, [specgram, rtf, psd_n, reference_channel])

def test_souden_mvdr(self):
transform = T.SoudenMVDR()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,25 @@ def test_MVDR(self, multi_mask):

self.assertEqual(computed, expected)

def test_rtf_mvdr(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
specgram = specgram.reshape(batch_size, channel, freq, time)
rtf = torch.rand(batch_size, freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
transform = T.RTFMVDR()

# Single then transform then batch
expected = [transform(specgram[i], rtf[i], psd_n[i], reference_channel) for i in range(batch_size)]
expected = torch.stack(expected)

# Batch then transform
computed = transform(specgram, rtf, psd_n, reference_channel)

self.assertEqual(computed, expected)

def test_souden_mvdr(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def test_MVDR(self, solution, online):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)

def test_rtf_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
channel, freq, _ = specgram.shape
rtf = torch.rand(freq, channel, dtype=self.complex_dtype, device=self.device)
psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
reference_channel = 0
self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)

def test_souden_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
Expand Down
38 changes: 21 additions & 17 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,35 +1892,39 @@ def mvdr_weights_rtf(
.. properties:: Autograd TorchScript
Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
:math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
{{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
where :math:`\bm{v}` is the RTF or the steering vector.
:math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
Args:
rtf (Tensor): The complex-valued RTF vector of target speech.
Tensor of dimension `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor, optional): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
rtf (torch.Tensor): The complex-valued RTF vector of target speech.
Tensor with dimensions `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
If a non-None value is given, the MVDR weights will be normalized by ``rtf[..., reference_channel].conj()``
(Default: ``None``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
psd_n = _tik_reg(psd_n, reg=diag_eps)
# numerator = psd_n.inv() @ stv
numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RNNTLoss,
PSD,
MVDR,
RTFMVDR,
SoudenMVDR,
)

Expand All @@ -46,6 +47,7 @@
"PSD",
"PitchShift",
"RNNTLoss",
"RTFMVDR",
"Resample",
"SlidingWindowCmn",
"SoudenMVDR",
Expand Down
64 changes: 64 additions & 0 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2091,6 +2091,70 @@ def forward(
return specgram_enhanced


class RTFMVDR(torch.nn.Module):
r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
:math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
The beamforming weight is computed by:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
{{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
"""

def forward(
self,
specgram: Tensor,
rtf: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
"""
Args:
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`
rtf (torch.Tensor): The complex-valued RTF vector of target speech.
Tensor with dimensions `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
"""
w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
return spectrum_enhanced


class SoudenMVDR(torch.nn.Module):
r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
Expand Down

0 comments on commit 4b021ae

Please sign in to comment.