Skip to content

Commit

Permalink
refactor test
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 18, 2022
1 parent dfebf0a commit d72cc84
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,13 +582,17 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)

def _apply_beamforming(self, beamform_weights, specgram):
specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram)
return specgram_enhanced

def test_apply_beamforming(self):
channel = 4
n_fft_bin = 10
frame = 5
beamform_weights = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
specgram = np.random.random((channel, n_fft_bin, frame)) + np.random.random((channel, n_fft_bin, frame)) * 1j
specgram_enhanced = np.einsum("...fc,...cft->...ft", beamform_weights.conj(), specgram)
specgram_enhanced = self._apply_beamforming(beamform_weights, specgram)
specgram_enhanced_audio = F.apply_beamforming(
torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
torch.tensor(specgram, dtype=self.complex_dtype, device=self.device),
Expand Down

0 comments on commit d72cc84

Please sign in to comment.