Skip to content

Commit

Permalink
move numpy method to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 21, 2022
1 parent 197b94e commit 8cd7e13
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/common_utils/beamform_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np


def mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
    """Compute RTF-based MVDR beamforming weights with NumPy.

    Used as a reference implementation in tests.

    Args:
        rtf: Complex array whose last axis is the channel axis, i.e.
            ``(..., channel)``.
        psd_n: Complex array whose last two axes are
            ``(channel, channel)``; leading axes must match those of ``rtf``.
        reference_channel: Either an integer index (Python ``int`` or NumPy
            integer scalar) into the channel axis, or an array whose last axis
            has length ``channel`` and is contracted against ``rtf.conj()``.
        diag_eps: Factor applied to the real part of the trace of ``psd_n``
            to build the diagonal loading term.
        eps: Small constant added to the diagonal loading term and to the
            denominator to avoid division by zero.

    Returns:
        Array of beamforming weights with the same shape as ``rtf``.
    """
    channel = rtf.shape[-1]

    # Diagonal loading: add a small multiple of the trace to the diagonal of
    # psd_n before solving. The trace is taken over the last two axes so that
    # any number of leading batch dimensions is supported.
    trace = np.trace(psd_n, axis1=-2, axis2=-1)
    epsilon = trace.real[..., None, None] * diag_eps + eps
    psd_n = psd_n + epsilon * np.eye(channel)

    # numerator = psd_n^{-1} @ rtf, computed via a linear solve.
    numerator = np.linalg.solve(psd_n, np.expand_dims(rtf, -1)).squeeze(-1)
    # denominator = rtf^H @ psd_n^{-1} @ rtf; only its real part is used.
    denominator = np.einsum("...d,...d->...", rtf.conj(), numerator)
    beamform_weights = numerator / (np.expand_dims(denominator.real, -1) + eps)

    # Scale by the conjugate RTF at the reference channel. NumPy integer
    # scalars are treated as indices, the same as Python ints.
    if isinstance(reference_channel, (int, np.integer)):
        scale = rtf[..., reference_channel].conj()
    else:
        scale = np.einsum("...c,...c->...", rtf.conj(), reference_channel[..., None, :])
    return beamform_weights * scale[..., None]
22 changes: 3 additions & 19 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
nested_params,
get_whitenoise,
rnnt_utils,
beamform_utils,
)


Expand Down Expand Up @@ -582,30 +583,13 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)

def _mvdr_weights_rtf_numpy(self, rtf, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
    """Compute RTF-based MVDR beamforming weights with NumPy.

    Used by the tests as a reference result. ``self`` is not used.

    Args:
        rtf: Complex array whose last axis is the channel axis, i.e.
            ``(..., channel)``.
        psd_n: Complex array whose last two axes are
            ``(channel, channel)``; leading axes must match those of ``rtf``.
        reference_channel: Either an integer index (Python ``int`` or NumPy
            integer scalar) into the channel axis, or an array whose last axis
            has length ``channel`` and is contracted against ``rtf.conj()``.
        diag_eps: Factor applied to the real part of the trace of ``psd_n``
            to build the diagonal loading term.
        eps: Small constant added to the diagonal loading term and to the
            denominator to avoid division by zero.

    Returns:
        Array of beamforming weights with the same shape as ``rtf``.
    """
    channel = rtf.shape[-1]

    # Diagonal loading; the trace is taken over the last two axes so leading
    # batch dimensions of any rank are supported.
    trace = np.trace(psd_n, axis1=-2, axis2=-1)
    epsilon = trace.real[..., None, None] * diag_eps + eps
    psd_n = psd_n + epsilon * np.eye(channel)

    # numerator = psd_n^{-1} @ rtf, computed via a linear solve.
    numerator = np.linalg.solve(psd_n, np.expand_dims(rtf, -1)).squeeze(-1)
    # denominator = rtf^H @ psd_n^{-1} @ rtf; only its real part is used.
    denominator = np.einsum("...d,...d->...", rtf.conj(), numerator)
    beamform_weights = numerator / (np.expand_dims(denominator.real, -1) + eps)

    # NumPy integer scalars are treated as indices, the same as Python ints.
    if isinstance(reference_channel, (int, np.integer)):
        scale = rtf[..., reference_channel].conj()
    else:
        scale = np.einsum("...c,...c->...", rtf.conj(), reference_channel[..., None, :])
    return beamform_weights * scale[..., None]

def test_mvdr_weights_rtf(self):
n_fft_bin = 10
channel = 4
reference_channel = 0
rtf = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
beamform_weights = self._mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
beamform_weights = beamform_utils.mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
beamform_weights_audio = F.mvdr_weights_rtf(
torch.tensor(rtf, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
Expand All @@ -625,7 +609,7 @@ def test_mvdr_weights_rtf_with_tensor(self):
reference_channel[0] = 1
rtf = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
beamform_weights = self._mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
beamform_weights = beamform_utils.mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
beamform_weights_audio = F.mvdr_weights_rtf(
torch.tensor(rtf, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
Expand Down

0 comments on commit 8cd7e13

Please sign in to comment.