Skip to content

Commit

Permalink
Add apply_beamforming to torchaudio.functional (pytorch#2232)
Browse files Browse the repository at this point in the history
Summary:
This PR adds ``apply_beamforming`` method to ``torchaudio.functional``.
The method employs the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.
The input arguments are the complex-valued beamforming weight Tensor and the multi-channel noisy spectrum.

Pull Request resolved: pytorch#2232

Reviewed By: mthrok

Differential Revision: D34474561

Pulled By: nateanl

fbshipit-source-id: 2910251a8f111e65375dfb50495b6a415113f06d
  • Loading branch information
nateanl authored and xiaohui-zhang committed May 4, 2022
1 parent 4862864 commit d8a8357
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ rtf_power

.. autofunction:: rtf_power

apply_beamforming
-----------------

.. autofunction:: apply_beamforming

:hidden:`Loss`
~~~~~~~~~~~~~~

Expand Down
5 changes: 5 additions & 0 deletions test/torchaudio_unittest/common_utils/beamform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,8 @@ def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
rtf = psd_n @ rtf
rtf = rtf.squeeze(-1)
return rtf


def apply_beamforming_numpy(beamform_weights, specgram):
specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram)
return specgram_enhanced
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
rnnt_utils,
)

Expand Down Expand Up @@ -333,6 +334,18 @@ def test_rtf_power_with_tensor(self, n_iter):
reference_channel[0].fill_(1)
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))

def test_apply_beamforming(self):
torch.random.manual_seed(2434)
sr = 8000
n_fft = 400
batch_size, num_channels = 2, 3
n_fft_bin = n_fft // 2 + 1
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=torch.cfloat)
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))


class AutogradFloat32(TestBaseMixin):
def assert_grad(
Expand Down
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,15 @@ def test_rtf_power_with_tensor(self, n_iter):
}
func = partial(F.rtf_power, **kwargs)
self.assert_batch_consistency(func, (psd_speech, psd_noise, reference_channel))

def test_apply_beamforming(self):
torch.random.manual_seed(2434)
sr = 8000
n_fft = 400
batch_size, num_channels = 2, 3
n_fft_bin = n_fft // 2 + 1
x = common_utils.get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
specgram = common_utils.get_spectrogram(x, n_fft=n_fft, hop_length=100)
specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram))
21 changes: 21 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,27 @@ def test_rtf_power_with_tensor(self, n_iter):
)
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

def test_apply_beamforming(self):
"""Verify ``F.apply_beamforming`` method by numpy implementation.
Given the multi-channel complex-valued spectrum and complex-valued
beamforming weights (Tensor of dimension `(..., freq, channel)`) as inputs,
``F.apply_beamforming`` outputs the single-channel complex-valued enhanced
spectrum, which should be identical to the output of ``apply_beamforming_numpy``.
"""
channel = 4
n_fft_bin = 10
frame = 5
beamform_weights = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
specgram = np.random.random((channel, n_fft_bin, frame)) + np.random.random((channel, n_fft_bin, frame)) * 1j
specgram_enhanced = beamform_utils.apply_beamforming_numpy(beamform_weights, specgram)
specgram_enhanced_audio = F.apply_beamforming(
torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
torch.tensor(specgram, dtype=self.complex_dtype, device=self.device),
)
self.assertEqual(
torch.tensor(specgram_enhanced, dtype=self.complex_dtype, device=self.device), specgram_enhanced_audio
)


class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,14 @@ def test_rtf_power_with_tensor(self, n_iter):
reference_channel[..., 0].fill_(1)
self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))

def test_apply_beamforming(self):
num_channels = 4
n_fft_bin = 201
num_frames = 10
beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=self.complex_dtype, device=self.device)
specgram = torch.rand(num_channels, n_fft_bin, num_frames, dtype=self.complex_dtype, device=self.device)
self._assert_consistency_complex(F.apply_beamforming, (beamform_weights, specgram))


class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
mvdr_weights_rtf,
rtf_evd,
rtf_power,
apply_beamforming,
)

__all__ = [
Expand Down Expand Up @@ -104,4 +105,5 @@
"mvdr_weights_rtf",
"rtf_evd",
"rtf_power",
"apply_beamforming",
]
24 changes: 24 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"mvdr_weights_rtf",
"rtf_evd",
"rtf_power",
"apply_beamforming",
]


Expand Down Expand Up @@ -1886,3 +1887,26 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor
# which is psd_n @ phi @ ref_channel
rtf = torch.matmul(psd_n, rtf)
return rtf.squeeze(-1)


def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.
.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin,
:math:`\textbf{Y}` is the multi-channel spectrum for the :math:`f`-th frequency bin.
Args:
beamform_weights (Tensor): The complex-valued beamforming weight matrix.
Tensor of dimension `(..., freq, channel)`
specgram (Tensor): The multi-channel complex-valued noisy spectrum.
Tensor of dimension `(..., channel, freq, time)`
Returns:
Tensor: The single-channel complex-valued enhanced spectrum.
Tensor of dimension `(..., freq, time)`
"""
# (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
return specgram_enhanced

0 comments on commit d8a8357

Please sign in to comment.