-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mvdr_weights_rtf to torchaudio.functional #2229
Conversation
9fa3b44
to
6500986
Compare
cc53427
to
e47e54c
Compare
@@ -582,6 +583,45 @@ def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self): | |||
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) | |||
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients) | |||
|
|||
def test_mvdr_weights_rtf(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring please.
torchaudio/functional/functional.py
Outdated
reference_channel = reference_channel.to(psd_n.dtype) | ||
scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]]) | ||
else: | ||
raise TypeError(f"Unsupported dtype for reference_channel. Found: {type(reference_channel)}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to https://github.com/pytorch/audio/pull/2231/files#r814462176, please write what is expected.
@nateanl has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Summary: This PR adds ``mvdr_weights_rtf`` method to ``torchaudio.functional``. It computes the MVDR weight matrix based on the solution that applies relative transfer function (RTF). See [the paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf) for the reference. The input arguments are the complex-valued RTF Tensor of the target speech, power spectral density (PSD) matrix of noise, int or one-hot Tensor to indicate the reference channel, respectively. Pull Request resolved: pytorch#2229 Reviewed By: mthrok Differential Revision: D34474119 Pulled By: nateanl fbshipit-source-id: ca20eca4d071ebb99f0b6827613338796f3ec2a2
14937a8
to
dc19b0d
Compare
This pull request was exported from Phabricator. Differential Revision: D34474119 |
Summary: This PR adds ``mvdr_weights_rtf`` method to ``torchaudio.functional``. It computes the MVDR weight matrix based on the solution that applies relative transfer function (RTF). See [the paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf) for the reference. The input arguments are the complex-valued RTF Tensor of the target speech, power spectral density (PSD) matrix of noise, int or one-hot Tensor to indicate the reference channel, respectively. Pull Request resolved: pytorch#2229 Reviewed By: mthrok Differential Revision: D34474119 Pulled By: nateanl fbshipit-source-id: 2d6f62cd0858f29ed6e4e03c23dcc11c816204e2
This PR adds
mvdr_weights_rtf
method totorchaudio.functional
.It computes the MVDR weight matrix based on the solution that applies relative transfer function (RTF). See the paper for the reference.
The input arguments are the complex-valued RTF Tensor of the target speech, power spectral density (PSD) matrix of noise, int or one-hot Tensor to indicate the reference channel, respectively.