diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py
index 8753bbad140..dfbb506d104 100644
--- a/test/torchaudio_unittest/functional/functional_impl.py
+++ b/test/torchaudio_unittest/functional/functional_impl.py
@@ -584,6 +584,13 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
         self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
 
     def test_mvdr_weights_rtf(self):
+        """Verify the ``F.mvdr_weights_rtf`` method against the numpy implementation.
+        Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
+        the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and an integer
+        indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the MVDR weights
+        (Tensor of dimension `(..., freq, channel)`), which should be close to the output of
+        ``mvdr_weights_rtf_numpy``.
+        """
         n_fft_bin = 10
         channel = 4
         reference_channel = 0
@@ -603,6 +610,13 @@ def test_mvdr_weights_rtf(self):
         )
 
     def test_mvdr_weights_rtf_with_tensor(self):
+        """Verify the ``F.mvdr_weights_rtf`` method against the numpy implementation.
+        Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
+        the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and a one-hot Tensor
+        indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the MVDR weights
+        (Tensor of dimension `(..., freq, channel)`), which should be close to the output of
+        ``mvdr_weights_rtf_numpy``.
+        """
         n_fft_bin = 10
         channel = 4
         reference_channel = np.zeros(channel)
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 28699eecf76..d6e8e8be7ab 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1727,7 +1727,7 @@ def mvdr_weights_rtf(
         reference_channel = reference_channel.to(psd_n.dtype)
         scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
     else:
-        raise TypeError(f"Unsupported dtype for reference_channel. Found: {type(reference_channel)}.")
+        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
     beamform_weights = beamform_weights * scale[..., None]
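
For context beyond the patch itself, below is a minimal usage sketch of the behaviour the two docstrings describe. The shapes follow the docstrings; the random-spectrum construction of `psd_n` and the `frames` variable are assumptions made only to produce a valid Hermitian input, and the call `F.mvdr_weights_rtf(rtf, psd_n, reference_channel=...)` follows the signature implied by the tests.

```python
import torch
import torchaudio.functional as F

freq, channel, frames = 10, 4, 100

# Relative transfer function (RTF) of the target speech: (..., freq, channel), complex.
rtf = torch.rand(freq, channel) + 1j * torch.rand(freq, channel)

# Noise PSD matrix: (..., freq, channel, channel), complex Hermitian.
# Built here from a random multichannel spectrum as X @ X^H / frames
# (an assumption of this sketch, just to obtain a valid input).
noise_spec = torch.rand(freq, channel, frames) + 1j * torch.rand(freq, channel, frames)
psd_n = torch.einsum("fct,fdt->fcd", noise_spec, noise_spec.conj()) / frames

# Reference channel given as an integer index (as in ``test_mvdr_weights_rtf``) ...
w_int = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0)

# ... or as a one-hot Tensor (as in ``test_mvdr_weights_rtf_with_tensor``).
one_hot = torch.zeros(channel)
one_hot[0] = 1.0
w_tensor = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=one_hot)

print(w_int.shape, w_tensor.shape)  # both are (freq, channel), i.e. (10, 4)
```

Any other type passed as `reference_channel` now raises the clarified `TypeError` ("Expected 'int' or 'Tensor' ...") introduced in the second hunk.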