diff --git a/docs/source/functional.rst b/docs/source/functional.rst index 417c192549..edec942731 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -266,6 +266,11 @@ rtf_power .. autofunction:: rtf_power +apply_beamforming +----------------- + +.. autofunction:: apply_beamforming + :hidden:`Loss` ~~~~~~~~~~~~~~ diff --git a/test/torchaudio_unittest/common_utils/beamform_utils.py b/test/torchaudio_unittest/common_utils/beamform_utils.py index 9f9d482910..96578dfa8e 100644 --- a/test/torchaudio_unittest/common_utils/beamform_utils.py +++ b/test/torchaudio_unittest/common_utils/beamform_utils.py @@ -70,3 +70,8 @@ def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter): rtf = psd_n @ rtf rtf = rtf.squeeze(-1) return rtf + + +def apply_beamforming_numpy(beamform_weights, specgram): + specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram) + return specgram_enhanced diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index 32e1bda203..176d372f3b 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -9,6 +9,7 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, get_whitenoise, + get_spectrogram, rnnt_utils, ) @@ -333,6 +334,18 @@ def test_rtf_power_with_tensor(self, n_iter): reference_channel[0].fill_(1) self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter)) + def test_apply_beamforming(self): + torch.random.manual_seed(2434) + sr = 8000 + n_fft = 400 + batch_size, num_channels = 2, 3 + n_fft_bin = n_fft // 2 + 1 + x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels) + specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100) + specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1)) + beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=torch.cfloat) + 
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram)) + class AutogradFloat32(TestBaseMixin): def assert_grad( diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py index 55c98d9120..355a5c610d 100644 --- a/test/torchaudio_unittest/functional/batch_consistency_test.py +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -415,3 +415,15 @@ def test_rtf_power_with_tensor(self, n_iter): } func = partial(F.rtf_power, **kwargs) self.assert_batch_consistency(func, (psd_speech, psd_noise, reference_channel)) + + def test_apply_beamforming(self): + torch.random.manual_seed(2434) + sr = 8000 + n_fft = 400 + batch_size, num_channels = 2, 3 + n_fft_bin = n_fft // 2 + 1 + x = common_utils.get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels) + specgram = common_utils.get_spectrogram(x, n_fft=n_fft, hop_length=100) + specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1)) + beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat) + self.assert_batch_consistency(F.apply_beamforming, (beamform_weights, specgram)) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 0508ea3d91..e22828aea0 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -795,6 +795,27 @@ def test_rtf_power_with_tensor(self, n_iter): ) self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio) + def test_apply_beamforming(self): + """Verify ``F.apply_beamforming`` method by numpy implementation. 
+ Given the multi-channel complex-valued spectrum and complex-valued + beamforming weights (Tensor of dimension `(..., freq, channel)`) as inputs, + ``F.apply_beamforming`` outputs the single-channel complex-valued enhanced + spectrum, which should be identical to the output of ``apply_beamforming_numpy``. + """ + channel = 4 + n_fft_bin = 10 + frame = 5 + beamform_weights = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j + specgram = np.random.random((channel, n_fft_bin, frame)) + np.random.random((channel, n_fft_bin, frame)) * 1j + specgram_enhanced = beamform_utils.apply_beamforming_numpy(beamform_weights, specgram) + specgram_enhanced_audio = F.apply_beamforming( + torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device), + torch.tensor(specgram, dtype=self.complex_dtype, device=self.device), + ) + self.assertEqual( + torch.tensor(specgram_enhanced, dtype=self.complex_dtype, device=self.device), specgram_enhanced_audio + ) + class FunctionalCPUOnly(TestBaseMixin): def test_melscale_fbanks_no_warning_high_n_freq(self): diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index faf682ac09..f1cdbd3629 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -727,6 +727,14 @@ def test_rtf_power_with_tensor(self, n_iter): reference_channel[..., 0].fill_(1) self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter)) + def test_apply_beamforming(self): + num_channels = 4 + n_fft_bin = 201 + num_frames = 10 + beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=self.complex_dtype, device=self.device) + specgram = torch.rand(num_channels, n_fft_bin, num_frames, dtype=self.complex_dtype, device=self.device) + self._assert_consistency_complex(F.apply_beamforming, 
(beamform_weights, specgram)) + class FunctionalFloat32Only(TestBaseMixin): def test_rnnt_loss(self): diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index 67c877ff94..1e8e0a89c1 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -51,6 +51,7 @@ mvdr_weights_rtf, rtf_evd, rtf_power, + apply_beamforming, ) __all__ = [ @@ -104,4 +105,5 @@ "mvdr_weights_rtf", "rtf_evd", "rtf_power", + "apply_beamforming", ] diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 4c47df5c91..5826f8feee 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -42,6 +42,7 @@ "mvdr_weights_rtf", "rtf_evd", "rtf_power", + "apply_beamforming", ] @@ -1886,3 +1887,26 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor # which is psd_n @ phi @ ref_channel rtf = torch.matmul(psd_n, rtf) return rtf.squeeze(-1) + + +def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor: + r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum. + + .. math:: + \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f) + where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin, + and :math:`\textbf{Y}(f)` is the multi-channel spectrum for the :math:`f`-th frequency bin. + + Args: + beamform_weights (Tensor): The complex-valued beamforming weight matrix. + Tensor of dimension `(..., freq, channel)` + specgram (Tensor): The multi-channel complex-valued noisy spectrum. + Tensor of dimension `(..., channel, freq, time)` + + Returns: + Tensor: The single-channel complex-valued enhanced spectrum. 
+ Tensor of dimension `(..., freq, time)` + """ + # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time) + specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram]) + return specgram_enhanced