From 62a859dd611c428880fb3e807fe1bf751df08d3c Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Mon, 21 Feb 2022 14:09:07 +0000 Subject: [PATCH] move numpy method to utils --- test/torchaudio_unittest/common_utils/beamform_utils.py | 6 ++++++ test/torchaudio_unittest/functional/functional_impl.py | 7 ++----- 2 files changed, 8 insertions(+), 5 deletions(-) create mode 100644 test/torchaudio_unittest/common_utils/beamform_utils.py diff --git a/test/torchaudio_unittest/common_utils/beamform_utils.py b/test/torchaudio_unittest/common_utils/beamform_utils.py new file mode 100644 index 00000000000..99b167e51c6 --- /dev/null +++ b/test/torchaudio_unittest/common_utils/beamform_utils.py @@ -0,0 +1,6 @@ +import numpy as np + + +def apply_beamforming_numpy(beamform_weights, specgram): + specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram) + return specgram_enhanced diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 7189e63dc9a..89eeb94ed88 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -14,6 +14,7 @@ nested_params, get_whitenoise, rnnt_utils, + beamform_utils, ) @@ -582,17 +583,13 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self): ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients) - def _apply_beamforming(self, beamform_weights, specgram): - specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram) - return specgram_enhanced - def test_apply_beamforming(self): channel = 4 n_fft_bin = 10 frame = 5 beamform_weights = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j specgram = np.random.random((channel, n_fft_bin, frame)) + np.random.random((channel, n_fft_bin, frame)) * 1j - specgram_enhanced = self._apply_beamforming(beamform_weights, specgram) + specgram_enhanced = beamform_utils.apply_beamforming_numpy(beamform_weights, specgram) specgram_enhanced_audio = F.apply_beamforming( torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device), torch.tensor(specgram, dtype=self.complex_dtype, device=self.device),