Skip to content

Commit

Permalink
add docstring for functional_impl test
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 25, 2022
1 parent 8cd7e13 commit 14937a8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)

def test_mvdr_weights_rtf(self):
"""Verify ``F.mvdr_weights_rtf`` method by numpy implementation.
Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and an integer
indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the mvdr weights
(Tensor of dimension `(..., freq, channel)`), which should be close to the output of
``mvdr_weights_rtf_numpy``.
"""
n_fft_bin = 10
channel = 4
reference_channel = 0
Expand All @@ -603,6 +610,13 @@ def test_mvdr_weights_rtf(self):
)

def test_mvdr_weights_rtf_with_tensor(self):
"""Verify ``F.mvdr_weights_rtf`` method by numpy implementation.
Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and a one-hot Tensor
indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the mvdr weights
(Tensor of dimension `(..., freq, channel)`), which should be close to the output of
``mvdr_weights_rtf_numpy``.
"""
n_fft_bin = 10
channel = 4
reference_channel = np.zeros(channel)
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ def mvdr_weights_rtf(
reference_channel = reference_channel.to(psd_n.dtype)
scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
else:
raise TypeError(f"Unsupported dtype for reference_channel. Found: {type(reference_channel)}.")
raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")

beamform_weights = beamform_weights * scale[..., None]

Expand Down

0 comments on commit 14937a8

Please sign in to comment.