Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add diagonal_loading optional to rtf_power #2369

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion test/torchaudio_unittest/common_utils/beamform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ def rtf_evd_numpy(psd):
return rtf


def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading=True, diag_eps=1e-7, eps=1e-8):
if diagonal_loading:
channel = psd_s.shape[-1]
eye = np.eye(channel)
trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
epsilon = trace.real[..., None, None] * diag_eps + eps
diag = epsilon * eye[..., :, :]
psd_n = psd_n + diag
phi = np.linalg.solve(psd_n, psd_s)
if isinstance(reference_channel, int):
rtf = phi[..., reference_channel]
Expand Down
16 changes: 8 additions & 8 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,33 +333,33 @@ def test_mvdr_weights_rtf_with_tensor(self):

@parameterized.expand(
[
(1,),
(3,),
(1, True),
(3, False),
]
)
def test_rtf_power(self, n_iter):
def test_rtf_power(self, n_iter, diagonal_loading):
torch.random.manual_seed(2434)
channel = 4
n_fft_bin = 5
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter))
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))

@parameterized.expand(
[
(1,),
(3,),
(1, True),
(3, False),
]
)
def test_rtf_power_with_tensor(self, n_iter):
def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
torch.random.manual_seed(2434)
channel = 4
n_fft_bin = 5
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
reference_channel = torch.zeros(channel)
reference_channel[0].fill_(1)
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))

def test_apply_beamforming(self):
torch.random.manual_seed(2434)
Expand Down
22 changes: 12 additions & 10 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,12 +740,12 @@ def test_rtf_evd(self):

@parameterized.expand(
[
(1,),
(2,),
(3,),
(1, True),
(2, False),
(3, True),
]
)
def test_rtf_power(self, n_iter):
def test_rtf_power(self, n_iter, diagonal_loading):
"""Verify ``F.rtf_power`` method by numpy implementation.
Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
an integer indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
Expand All @@ -757,23 +757,24 @@ def test_rtf_power(self, n_iter):
reference_channel = 0
psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
rtf_audio = F.rtf_power(
torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
reference_channel,
n_iter,
diagonal_loading=diagonal_loading,
)
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

@parameterized.expand(
[
(1,),
(2,),
(3,),
(1, True),
(2, False),
(3, True),
]
)
def test_rtf_power_with_tensor(self, n_iter):
def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
"""Verify ``F.rtf_power`` method by numpy implementation.
Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
a one-hot Tensor indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
Expand All @@ -786,12 +787,13 @@ def test_rtf_power_with_tensor(self, n_iter):
reference_channel[0] = 1
psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
rtf_audio = F.rtf_power(
torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
n_iter,
diagonal_loading=diagonal_loading,
)
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,32 +700,38 @@ def test_rtf_evd(self):

@parameterized.expand(
[
(1,),
(3,),
(1, True),
(3, False),
]
)
def test_rtf_power(self, n_iter):
def test_rtf_power(self, n_iter, diagonal_loading):
channel = 4
n_fft_bin = 10
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
reference_channel = 0
self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
diag_eps = 1e-7
self._assert_consistency_complex(
F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
)

@parameterized.expand(
[
(1,),
(3,),
(1, True),
(3, False),
]
)
def test_rtf_power_with_tensor(self, n_iter):
def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
channel = 4
n_fft_bin = 10
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
reference_channel = torch.zeros(channel)
reference_channel[..., 0].fill_(1)
self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
diag_eps = 1e-7
self._assert_consistency_complex(
F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
)

def test_apply_beamforming(self):
num_channels = 4
Expand Down
35 changes: 24 additions & 11 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1958,29 +1958,42 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
return rtf


def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
def rtf_power(
psd_s: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
n_iter: int = 3,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
) -> Tensor:
r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

.. devices:: CPU CUDA

.. properties:: Autograd TorchScript

Args:
psd_s (Tensor): The complex-valued covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
n_iter (int): number of iterations in power method. (Default: ``3``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)

Returns:
Tensor: the estimated complex-valued RTF of target speech
Tensor of dimension `(..., freq, channel)`
torch.Tensor: The estimated complex-valued RTF of target speech.
Tensor of dimension `(..., freq, channel)`.
"""
assert n_iter > 0, "The number of iteration must be greater than 0."
# Apply diagonal loading to psd_n to improve robustness.
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps)
# phi is regarded as the first iteration
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
if torch.jit.isinstance(reference_channel, int):
Expand Down