diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index cb1b7b96b28..7189e63dc9a 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -582,13 +582,17 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self): ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients) + def _apply_beamforming(self, beamform_weights, specgram): + specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram) + return specgram_enhanced + def test_apply_beamforming(self): channel = 4 n_fft_bin = 10 frame = 5 beamform_weights = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j specgram = np.random.random((channel, n_fft_bin, frame)) + np.random.random((channel, n_fft_bin, frame)) * 1j - specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram) + specgram_enhanced = self._apply_beamforming(beamform_weights, specgram) specgram_enhanced_audio = F.apply_beamforming( torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device), torch.tensor(specgram, dtype=self.complex_dtype, device=self.device),