Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add apply_beamforming to torchaudio.functional #2232

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ rtf_power

.. autofunction:: rtf_power

apply_beamforming
-----------------

.. autofunction:: apply_beamforming

:hidden:`Loss`
~~~~~~~~~~~~~~

Expand Down
5 changes: 5 additions & 0 deletions test/torchaudio_unittest/common_utils/beamform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,8 @@ def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
rtf = psd_n @ rtf
rtf = rtf.squeeze(-1)
return rtf


def apply_beamforming_numpy(beamform_weights, specgram):
    """Reference numpy implementation of beamforming application.

    Contracts the conjugated beamforming weights ``(..., freq, channel)``
    against the multi-channel spectrum ``(..., channel, freq, time)`` over the
    channel axis, yielding the single-channel enhanced spectrum
    ``(..., freq, time)``.
    """
    return np.einsum("...fc,...cft->...ft", np.conj(beamform_weights), specgram)
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
rnnt_utils,
)

Expand Down Expand Up @@ -333,6 +334,18 @@ def test_rtf_power_with_tensor(self, n_iter):
reference_channel[0].fill_(1)
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))

def test_apply_beamforming(self):
    """Gradient check for ``F.apply_beamforming``.

    Builds a batched multi-channel complex spectrogram from white noise and
    random complex beamforming weights of shape `(freq, channel)` (shared
    across the batch — no leading batch dimension), then runs the class's
    ``assert_grad`` check on ``F.apply_beamforming`` with both as inputs.
    """
    torch.random.manual_seed(2434)
    sr = 8000
    n_fft = 400
    batch_size, num_channels = 2, 3
    n_fft_bin = n_fft // 2 + 1
    # Generate batch_size * num_channels channels of noise, then fold the
    # channel axis into a (batch, channel, freq, time) spectrogram.
    x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
    specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
    specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
    beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=torch.cfloat)
    self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))


class AutogradFloat32(TestBaseMixin):
def assert_grad(
Expand Down
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,15 @@ def test_rtf_power_with_tensor(self, n_iter):
}
func = partial(F.rtf_power, **kwargs)
self.assert_batch_consistency(func, (psd_speech, psd_noise, reference_channel))

def test_apply_beamforming(self):
    """Batch consistency check for ``F.apply_beamforming``.

    Builds a batched multi-channel complex spectrogram from white noise and
    random complex beamforming weights of shape `(batch, freq, channel)`,
    then runs the class's ``assert_batch_consistency`` check on
    ``F.apply_beamforming`` — i.e. the batched call should agree with
    per-sample calls.
    """
    torch.random.manual_seed(2434)
    sr = 8000
    n_fft = 400
    batch_size, num_channels = 2, 3
    n_fft_bin = n_fft // 2 + 1
    # Generate batch_size * num_channels channels of noise, then fold the
    # channel axis into a (batch, channel, freq, time) spectrogram.
    x = common_utils.get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
    specgram = common_utils.get_spectrogram(x, n_fft=n_fft, hop_length=100)
    specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
    beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
    self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram))
21 changes: 21 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,27 @@ def test_rtf_power_with_tensor(self, n_iter):
)
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

def test_apply_beamforming(self):
    """Verify ``F.apply_beamforming`` against the numpy reference implementation.

    Given a complex-valued multi-channel spectrum (Tensor of dimension
    `(channel, freq, time)`) and complex-valued beamforming weights (Tensor of
    dimension `(freq, channel)`) as inputs, ``F.apply_beamforming`` outputs a
    single-channel complex-valued enhanced spectrum which should be identical
    to the output of ``apply_beamforming_numpy`` on the same inputs.
    """
    num_channels = 4
    num_freq_bins = 10
    num_frames = 5
    # Random complex inputs: independent uniform real and imaginary parts.
    weights_shape = (num_freq_bins, num_channels)
    specgram_shape = (num_channels, num_freq_bins, num_frames)
    beamform_weights = np.random.random(weights_shape) + 1j * np.random.random(weights_shape)
    specgram = np.random.random(specgram_shape) + 1j * np.random.random(specgram_shape)
    expected = beamform_utils.apply_beamforming_numpy(beamform_weights, specgram)
    actual = F.apply_beamforming(
        torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
        torch.tensor(specgram, dtype=self.complex_dtype, device=self.device),
    )
    self.assertEqual(torch.tensor(expected, dtype=self.complex_dtype, device=self.device), actual)


class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,14 @@ def test_rtf_power_with_tensor(self, n_iter):
reference_channel[..., 0].fill_(1)
self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))

def test_apply_beamforming(self):
    """TorchScript consistency check for ``F.apply_beamforming``.

    Given random complex beamforming weights of shape `(freq, channel)` and a
    random complex spectrum of shape `(channel, freq, time)`, runs the class's
    ``_assert_consistency_complex`` check — i.e. the scripted version of
    ``F.apply_beamforming`` should produce the same output as the eager one.
    """
    num_channels = 4
    n_fft_bin = 201
    num_frames = 10
    beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=self.complex_dtype, device=self.device)
    specgram = torch.rand(num_channels, n_fft_bin, num_frames, dtype=self.complex_dtype, device=self.device)
    self._assert_consistency_complex(F.apply_beamforming, (beamform_weights, specgram))


class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
mvdr_weights_rtf,
rtf_evd,
rtf_power,
apply_beamforming,
)

__all__ = [
Expand Down Expand Up @@ -104,4 +105,5 @@
"mvdr_weights_rtf",
"rtf_evd",
"rtf_power",
"apply_beamforming",
]
24 changes: 24 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"mvdr_weights_rtf",
"rtf_evd",
"rtf_power",
"apply_beamforming",
]


Expand Down Expand Up @@ -1886,3 +1887,26 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor
# which is psd_n @ phi @ ref_channel
rtf = torch.matmul(psd_n, rtf)
return rtf.squeeze(-1)


def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
    r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin,
    :math:`\textbf{Y}` is the multi-channel spectrum for the :math:`f`-th frequency bin.

    Args:
        beamform_weights (Tensor): The complex-valued beamforming weight matrix.
            Tensor of dimension `(..., freq, channel)`
        specgram (Tensor): The multi-channel complex-valued noisy spectrum.
            Tensor of dimension `(..., channel, freq, time)`

    Returns:
        Tensor: The single-channel complex-valued enhanced spectrum.
            Tensor of dimension `(..., freq, time)`
    """
    # Bring the spectrum to (..., freq, channel, time) so every frequency bin
    # holds a (channel, time) matrix, then left-multiply each of them by the
    # conjugated weight row vector and drop the resulting singleton dimension:
    # (..., freq, 1, channel) @ (..., freq, channel, time) -> (..., freq, time)
    weights_h = beamform_weights.conj().unsqueeze(-2)
    mixture = specgram.transpose(-3, -2)
    specgram_enhanced = torch.matmul(weights_h, mixture).squeeze(-2)
    return specgram_enhanced