Skip to content

Commit

Permalink
rename method to mvdr_weights_rtf
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 17, 2022
1 parent 6500986 commit cc53427
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 25 deletions.
4 changes: 2 additions & 2 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ treble_biquad

.. autofunction:: spectral_centroid

compute_mvdr_weights_rtf
mvdr_weights_rtf
------------------------

.. autofunction:: compute_mvdr_weights_rtf
.. autofunction:: mvdr_weights_rtf

:hidden:`Loss`
~~~~~~~~~~~~~~
Expand Down
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,16 @@ def test_bandreject_biquad(self, central_freq, Q):
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

def test_compute_mvdr_weights_rtf(self):
def test_mvdr_weights_rtf(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
n_fft_bin = 10
rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
self.assert_grad(F.compute_mvdr_weights_rtf, (rtf, psd_noise, 0))
self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

def test_compute_mvdr_weights_rtf_with_tensor(self):
def test_mvdr_weights_rtf_with_tensor(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
Expand All @@ -268,7 +268,7 @@ def test_compute_mvdr_weights_rtf_with_tensor(self):
psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
reference_channel = torch.zeros(batch_size, channel)
reference_channel[..., 0].fill_(1)
self.assert_grad(F.compute_mvdr_weights_rtf, (rtf, psd_noise, reference_channel))
self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))


class AutogradFloat32(TestBaseMixin):
Expand Down
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_filtfilt(self):
b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))

def test_compute_mvdr_weights_rtf(self):
def test_mvdr_weights_rtf(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
Expand All @@ -305,10 +305,10 @@ def test_compute_mvdr_weights_rtf(self):
kwargs = {
"reference_channel": 0,
}
func = partial(F.compute_mvdr_weights_rtf, **kwargs)
func = partial(F.mvdr_weights_rtf, **kwargs)
self.assert_batch_consistency(func, (rtf, psd_noise))

def test_compute_mvdr_weights_rtf_with_tensor(self):
def test_mvdr_weights_rtf_with_tensor(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
Expand All @@ -317,4 +317,4 @@ def test_compute_mvdr_weights_rtf_with_tensor(self):
psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
reference_channel = torch.zeros(batch_size, channel)
reference_channel[..., 0].fill_(1)
self.assert_batch_consistency(F.compute_mvdr_weights_rtf, (rtf, psd_noise, reference_channel))
self.assert_batch_consistency(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))
Original file line number Diff line number Diff line change
Expand Up @@ -617,19 +617,19 @@ def test_phase_vocoder(self):
)[..., None]
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))

def test_compute_mvdr_weights_rtf(self):
def test_mvdr_weights_rtf(self):
def func(rtf, psd_noise, reference_channel: int):
return F.compute_mvdr_weights_rtf(rtf, psd_noise, reference_channel)
return F.mvdr_weights_rtf(rtf, psd_noise, reference_channel)

channel = 4
n_fft_bin = 10
rtf = torch.rand(n_fft_bin, channel, dtype=self.complex_dtype)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
self._assert_consistency_complex(func, (rtf, psd_noise, 0))

def test_compute_mvdr_weights_rtf_with_tensor(self):
def test_mvdr_weights_rtf_with_tensor(self):
def func(rtf, psd_noise, reference_channel: torch.Tensor):
return F.compute_mvdr_weights_rtf(rtf, psd_noise, reference_channel)
return F.mvdr_weights_rtf(rtf, psd_noise, reference_channel)

channel = 4
n_fft_bin = 10
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
edit_distance,
pitch_shift,
rnnt_loss,
compute_mvdr_weights_rtf,
mvdr_weights_rtf,
)

__all__ = [
Expand Down Expand Up @@ -95,5 +95,5 @@
"edit_distance",
"pitch_shift",
"rnnt_loss",
"compute_mvdr_weights_rtf",
"mvdr_weights_rtf",
]
19 changes: 10 additions & 9 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,11 +1674,10 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T
return mat


def compute_mvdr_weights_rtf(
def mvdr_weights_rtf(
rtf: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
normalize: bool = True,
reference_channel: Optional[Union[int, Tensor]] = None,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
Expand All @@ -1698,13 +1697,15 @@ def compute_mvdr_weights_rtf(
Tensor of dimension `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
reference_channel (int or Tensor, optional): Indicate the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
normalize (bool, optional): whether to normalize the RTF vector. (Default: ``True``)
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
is one-hot.
If a non-None value is given, the MVDR weights will be normalized by ``rtf[..., reference_channel].conj()``
(Default: ``None``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in the beamforming weight formula. (Default: ``1e-8``)
Expand All @@ -1718,8 +1719,8 @@ def compute_mvdr_weights_rtf(
# denominator = stv^H @ psd_n.inv() @ stv
denominator = torch.einsum("...d,...d->...", [rtf.conj(), numerator])
beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
# normalzie the numerator
if normalize:
# normalize the numerator
if reference_channel is not None:
if isinstance(reference_channel, int):
scale = rtf[..., reference_channel].conj()
else:
Expand Down

0 comments on commit cc53427

Please sign in to comment.