Add RTFMVDR module #2368

Closed · wants to merge 1 commit
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`RTFMVDR`
-----------------

.. autoclass:: RTFMVDR

.. automethod:: forward

:hidden:`SoudenMVDR`
--------------------

10 changes: 10 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -303,6 +303,16 @@ def test_mvdr(self, solution):
        mask_n = torch.rand(spectrogram.shape[-2:])
        self.assert_grad(transform, [spectrogram, mask_s, mask_n])

    def test_rtf_mvdr(self):
        transform = T.RTFMVDR()
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        specgram = get_spectrogram(waveform, n_fft=400)
        channel, freq, _ = specgram.shape
        rtf = torch.rand(freq, channel, dtype=torch.cfloat)
        psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
        reference_channel = 0
        self.assert_grad(transform, [specgram, rtf, psd_n, reference_channel])

    def test_souden_mvdr(self):
        transform = T.SoudenMVDR()
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -220,6 +220,25 @@ def test_MVDR(self, multi_mask):

        self.assertEqual(computed, expected)

    def test_rtf_mvdr(self):
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
        batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
        specgram = specgram.reshape(batch_size, channel, freq, time)
        rtf = torch.rand(batch_size, freq, channel, dtype=torch.cfloat)
        psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
        reference_channel = 0
        transform = T.RTFMVDR()

        # Transform each example individually, then stack into a batch
        expected = [transform(specgram[i], rtf[i], psd_n[i], reference_channel) for i in range(batch_size)]
        expected = torch.stack(expected)

        # Transform the whole batch at once
        computed = transform(specgram, rtf, psd_n, reference_channel)

        self.assertEqual(computed, expected)

    def test_souden_mvdr(self):
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
@@ -175,6 +175,15 @@ def test_MVDR(self, solution, online):
        mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
        self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)

    def test_rtf_mvdr(self):
        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
        specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
        channel, freq, _ = specgram.shape
        rtf = torch.rand(freq, channel, dtype=self.complex_dtype, device=self.device)
        psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
        reference_channel = 0
        self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)

    def test_souden_mvdr(self):
        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
        specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
38 changes: 21 additions & 17 deletions torchaudio/functional/functional.py
@@ -1892,35 +1892,39 @@ def mvdr_weights_rtf(

    .. properties:: Autograd TorchScript

    Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}

    where :math:`(.)^{\mathsf{H}}` denotes the Hermitian conjugate operation.

    Args:
        rtf (torch.Tensor): The complex-valued RTF vector of target speech.
            Tensor with dimensions `(..., freq, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_channel (int or torch.Tensor): Specifies the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
            is one-hot.
        diagonal_loading (bool, optional): If ``True``, applies diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
            (Default: ``1e-8``)

    Returns:
        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
    """
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps)
    # numerator = psd_n.inv() @ stv
    numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)  # (..., freq, channel)
    # denominator = stv^H @ psd_n.inv() @ stv
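
For context, a minimal usage sketch of the functional helper documented above. The random ``rtf`` and ``psd_n`` below are illustrative stand-ins for real RTF and noise-PSD estimates, and the snippet assumes ``mvdr_weights_rtf`` is exported from ``torchaudio.functional`` (the new transform below calls it as ``F.mvdr_weights_rtf``):

```python
import torch
import torchaudio.functional as F

channel, freq = 4, 201  # e.g. 201 frequency bins for n_fft=400

# Illustrative stand-ins; in practice these come from RTF / noise-PSD estimation.
rtf = torch.rand(freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)

w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0)
print(w_mvdr.shape)  # torch.Size([201, 4]), i.e. (freq, channel)
```
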
2 changes: 2 additions & 0 deletions torchaudio/transforms/__init__.py
@@ -24,6 +24,7 @@
    RNNTLoss,
    PSD,
    MVDR,
    RTFMVDR,
    SoudenMVDR,
)

@@ -46,6 +47,7 @@
"PSD",
"PitchShift",
"RNNTLoss",
"RTFMVDR",
"Resample",
"SlidingWindowCmn",
"SoudenMVDR",
64 changes: 64 additions & 0 deletions torchaudio/transforms/_transforms.py
@@ -2091,6 +2091,70 @@ def forward(
        return specgram_enhanced


class RTFMVDR(torch.nn.Module):
    r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin and
    :math:`(.)^{\mathsf{H}}` denotes the Hermitian conjugate operation.

    The beamforming weight is computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
    """

    def forward(
        self,
        specgram: Tensor,
        rtf: Tensor,
        psd_n: Tensor,
        reference_channel: Union[int, Tensor],
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> Tensor:
        """
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
                Tensor with dimensions `(..., freq, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
                is one-hot.
            diagonal_loading (bool, optional): If ``True``, applies diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        """
        w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
        return spectrum_enhanced
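
For completeness, a minimal end-to-end sketch of the new transform, mirroring the shapes used in the unit tests above. The random ``specgram``, ``rtf``, and ``psd_n`` are illustrative stand-ins; a real pipeline would take the STFT of a multi-channel recording and estimate the RTF and the noise PSD (e.g. with ``torchaudio.transforms.PSD``).

```python
import torch
import torchaudio.transforms as T

transform = T.RTFMVDR()

channel, freq, time = 2, 201, 100
specgram = torch.rand(channel, freq, time, dtype=torch.cfloat)  # multi-channel complex spectrum
rtf = torch.rand(freq, channel, dtype=torch.cfloat)             # RTF / steering-vector stand-in
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)  # noise-PSD stand-in

enhanced = transform(specgram, rtf, psd_n, reference_channel=0)
print(enhanced.shape)  # torch.Size([201, 100]), i.e. (freq, time)
```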


class SoudenMVDR(torch.nn.Module):
    r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
    based on the method proposed by *Souden et al.* [:footcite:`souden2009optimal`].