Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SoudenMVDR module #2367

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`SoudenMVDR`
--------------------

.. autoclass:: SoudenMVDR

.. automethod:: forward

References
~~~~~~~~~~

Expand Down
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,16 @@ def test_mvdr(self, solution):
mask_n = torch.rand(spectrogram.shape[-2:])
self.assert_grad(transform, [spectrogram, mask_s, mask_n])

def test_souden_mvdr(self):
transform = T.SoudenMVDR()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
specgram = get_spectrogram(waveform, n_fft=400)
channel, freq, _ = specgram.shape
psd_s = torch.rand(freq, channel, channel, dtype=torch.cfloat)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
self.assert_grad(transform, [specgram, psd_s, psd_n, reference_channel])


class AutogradTestFloat32(TestBaseMixin):
def assert_grad(
Expand Down
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,22 @@ def test_MVDR(self, multi_mask):
computed = transform(specgram, mask_s, mask_n)

self.assertEqual(computed, expected)

def test_souden_mvdr(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
specgram = specgram.reshape(batch_size, channel, freq, time)
psd_s = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
transform = T.SoudenMVDR()

# Single then transform then batch
expected = [transform(specgram[i], psd_s[i], psd_n[i], reference_channel) for i in range(batch_size)]
expected = torch.stack(expected)

# Batch then transform
computed = transform(specgram, psd_s, psd_n, reference_channel)

self.assertEqual(computed, expected)
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def test_MVDR(self, solution, online):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)

def test_souden_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
channel, freq, _ = specgram.shape
psd_s = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
reference_channel = 0
self._assert_consistency_complex(T.SoudenMVDR(), specgram, psd_s, psd_n, reference_channel)


class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down
35 changes: 19 additions & 16 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,34 +1831,37 @@ def mvdr_weights_souden(

.. properties:: Autograd TorchScript

Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
:math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:

.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
are the power spectral density (PSD) matrices of speech and noise, respectively.
:math:`\bf{u}` is a one-hot vector that represents the reference channel.

Args:
psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
carolineechen marked this conversation as resolved.
Show resolved Hide resolved
(Default: ``1e-8``)

Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
psd_n = _tik_reg(psd_n, reg=diag_eps)
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RNNTLoss,
PSD,
MVDR,
SoudenMVDR,
)


Expand All @@ -47,6 +48,7 @@
"RNNTLoss",
"Resample",
"SlidingWindowCmn",
"SoudenMVDR",
"SpectralCentroid",
"Spectrogram",
"TimeMasking",
Expand Down
65 changes: 64 additions & 1 deletion torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
import warnings
from typing import Callable, Optional
from typing import Callable, Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -2089,3 +2089,66 @@ def forward(

specgram_enhanced.to(dtype)
return specgram_enhanced


class SoudenMVDR(torch.nn.Module):
r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].

.. devices:: CPU CUDA

.. properties:: Autograd TorchScript

nateanl marked this conversation as resolved.
Show resolved Hide resolved
Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.

The beamforming weight is computed by:

.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is "tr" more standard?

Suggested change
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
{\text{tr}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw different usages in different publications, trace in https://www.merl.com/publications/docs/TR2016-072.pdf, Trace in https://arxiv.org/pdf/2005.10479.pdf, and Tr in https://ieeexplore.ieee.org/abstract/document/7952756.
For a better understanding we can put Trace here.

"""

def forward(
self,
specgram: Tensor,
psd_s: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> torch.Tensor:
"""
Args:
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)

Returns:
torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
"""
w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
return spectrum_enhanced