Add diagonal_loading optional to rtf_power #2369

9 changes: 8 additions & 1 deletion test/torchaudio_unittest/common_utils/beamform_utils.py
@@ -55,7 +55,14 @@ def rtf_evd_numpy(psd):
    return rtf


def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading=True, diag_eps=1e-7, eps=1e-8):
    if diagonal_loading:
        channel = psd_s.shape[-1]
        eye = np.eye(channel)
        trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
        epsilon = trace.real[..., None, None] * diag_eps + eps
        diag = epsilon * eye[..., :, :]
        psd_n = psd_n + diag
    phi = np.linalg.solve(psd_n, psd_s)
    if isinstance(reference_channel, int):
        rtf = phi[..., reference_channel]
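For reference, the inline diagonal-loading block above is equivalent to a small standalone helper. A minimal sketch, assuming the same `(..., channel, channel)` layout used in this file (the name `tik_reg_numpy` and its defaults are illustrative, not part of this diff):

```python
import numpy as np


def tik_reg_numpy(mat, reg=1e-7, eps=1e-8):
    """Add a trace-scaled identity to each (channel, channel) block of ``mat``."""
    channel = mat.shape[-1]
    eye = np.eye(channel)
    # The real part of the per-matrix trace sets the scale of the regularizer;
    # ``eps`` keeps the loading term non-zero even when the trace is zero.
    trace = np.trace(mat, axis1=-2, axis2=-1).real
    epsilon = trace[..., None, None] * reg + eps
    return mat + epsilon * eye
```

The sketch takes the trace over the last two axes; the diff's `np.matrix.trace(psd_n, axis1=1, axis2=2)` is equivalent for the 3-D inputs exercised by the tests below.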
11 changes: 6 additions & 5 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -740,12 +740,12 @@ def test_rtf_evd(self):

    @parameterized.expand(
        [
            (1,),
            (2,),
            (3,),
            (1, True),
            (2, False),
            (3, True),
        ]
    )
    def test_rtf_power(self, n_iter):
    def test_rtf_power(self, n_iter, diagonal_loading):
        """Verify ``F.rtf_power`` method by numpy implementation.
        Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel)`),
        an integer indicating the reference channel, and an integer for the number of iterations, ``F.rtf_power``
@@ -757,12 +757,13 @@ def test_rtf_power(self, n_iter):
        reference_channel = 0
        psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
        psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
        rtf_audio = F.rtf_power(
            torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
            torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
            reference_channel,
            n_iter,
            diagonal_loading=diagonal_loading,
        )
        self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

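Outside the test harness, the new keyword looks like this. A hedged usage sketch, assuming a build that includes this PR (tensor values and the choice of `reference_channel` are illustrative):

```python
import torch
import torchaudio.functional as F

n_fft_bin, channel = 10, 4
psd_s = torch.rand(n_fft_bin, channel, channel, dtype=torch.cdouble)
psd_n = torch.rand(n_fft_bin, channel, channel, dtype=torch.cdouble)

# With diagonal loading (the default), psd_n is regularized before inversion.
rtf_loaded = F.rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3, diagonal_loading=True)
# Without diagonal loading, psd_n is used as-is.
rtf_plain = F.rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3, diagonal_loading=False)
print(rtf_loaded.shape)  # torch.Size([10, 4]) -> (..., freq, channel)
```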
@@ -700,32 +700,40 @@ def test_rtf_evd(self):

    @parameterized.expand(
        [
            (1,),
            (3,),
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power(self, n_iter):
    def test_rtf_power(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 10
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = 0
        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
        diag_eps = 1e-7
        eps = 1e-8
        self._assert_consistency_complex(
            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps, eps)
        )

    @parameterized.expand(
        [
            (1,),
            (3,),
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power_with_tensor(self, n_iter):
    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 10
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(channel)
        reference_channel[..., 0].fill_(1)
        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
        diag_eps = 1e-7
        eps = 1e-8
        self._assert_consistency_complex(
            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps, eps)
        )

    def test_apply_beamforming(self):
        num_channels = 4
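The two tests above appear to be TorchScript consistency tests: the new arguments are appended to the positional tuple handed to `_assert_consistency_complex`. A minimal sketch of the kind of check that helper presumably performs (its internals are an assumption; `torch.jit.script` and `torch.testing.assert_close` are the only real APIs used here, and `check_rtf_power_scriptability` is a hypothetical name):

```python
import torch
import torchaudio.functional as F


def check_rtf_power_scriptability(n_iter: int = 3, diagonal_loading: bool = True) -> None:
    """Compare eager vs. TorchScript-scripted output of F.rtf_power."""
    channel, n_fft_bin = 4, 10
    psd_s = torch.rand(n_fft_bin, channel, channel, dtype=torch.cdouble)
    psd_n = torch.rand(n_fft_bin, channel, channel, dtype=torch.cdouble)
    # Eager call with all positional arguments, matching the tuple in the tests.
    eager = F.rtf_power(psd_s, psd_n, 0, n_iter, diagonal_loading, 1e-7, 1e-8)
    # Scripted call must accept the same arguments and produce the same result.
    scripted_fn = torch.jit.script(F.rtf_power)
    scripted = scripted_fn(psd_s, psd_n, 0, n_iter, diagonal_loading, 1e-7, 1e-8)
    torch.testing.assert_close(eager, scripted)
```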
18 changes: 17 additions & 1 deletion torchaudio/functional/functional.py
@@ -1958,7 +1958,15 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
    return rtf


def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
def rtf_power(
    psd_s: Tensor,
    psd_n: Tensor,
    reference_channel: Union[int, Tensor],
    n_iter: int = 3,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

    .. devices:: CPU CUDA
@@ -1975,12 +1983,20 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor
            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
            is one-hot.
        n_iter (int): number of iterations in power method. (Default: ``3``)
        diagonal_loading (bool, optional): whether to apply diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): the coefficient multiplied with the identity matrix for diagonal loading;
            only used when ``diagonal_loading`` is ``True``. (Default: ``1e-7``)
        eps (float, optional): a value added to the diagonal-loading coefficient to keep it from being zero;
            only used when ``diagonal_loading`` is ``True``. (Default: ``1e-8``)
Contributor

Can this docstring be more helpful? Does "a value added to the denominator in the beamforming weight computation" from #2368 make sense here, and is it worth adding that this is only used for the case when diagonal_loading=True?

Member Author

Makes sense. I will align the eps docstring in the functions and modules.

Member Author

The eps here is for diagonal loading, which is easy to confuse with the eps used in computing the beamforming weight. I decided to exclude it from the API and use the default value in _tik_reg.


    Returns:
        Tensor: the estimated complex-valued RTF of target speech
        Tensor of dimension `(..., freq, channel)`
    """
    assert n_iter > 0, "The number of iterations must be greater than 0."
    # Apply diagonal loading to psd_n to improve robustness.
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
    # phi is regarded as the first iteration
    phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    if torch.jit.isinstance(reference_channel, int):
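`_tik_reg` is torchaudio's internal Tikhonov-regularization helper and its body is not shown in this diff. A minimal sketch of the regularization it is expected to apply, mirroring the numpy reference added above for complex PSD inputs (the name `tik_reg_sketch` and its defaults are illustrative):

```python
import torch
from torch import Tensor


def tik_reg_sketch(mat: Tensor, reg: float = 1e-7, eps: float = 1e-8) -> Tensor:
    """Add a trace-scaled identity to each (channel, channel) block of ``mat``."""
    channel = mat.size(-1)
    eye = torch.eye(channel, dtype=mat.dtype, device=mat.device)
    # The real part of the per-matrix trace sets the magnitude of the regularizer;
    # ``eps`` keeps the loading term non-zero even when the trace is zero.
    trace = torch.diagonal(mat, dim1=-2, dim2=-1).sum(-1).real
    epsilon = trace[..., None, None] * reg + eps
    return mat + epsilon * eye
```

With a helper like this, the `diagonal_loading=True` path in `rtf_power` amounts to regularizing `psd_n` once before the first `torch.linalg.solve`, which keeps the solve well-conditioned when `psd_n` is close to singular.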