Skip to content

Commit

Permalink
refactor reference_channel argument
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 6, 2022
1 parent da65d11 commit c5c04d1
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 38 deletions.
91 changes: 61 additions & 30 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
import torchaudio
Expand Down Expand Up @@ -1077,18 +1077,18 @@ def sliding_window_cmn(
input_part = specgram[:, window_start : window_end - window_start, :]
cur_sum += torch.sum(input_part, 1)
if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
cur_sumsq += torch.cumsum(input_part**2, 1)[:, -1, :]
else:
if window_start > last_window_start:
frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove
if norm_vars:
cur_sumsq -= frame_to_remove ** 2
cur_sumsq -= frame_to_remove**2
if window_end > last_window_end:
frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add
if norm_vars:
cur_sumsq += frame_to_add ** 2
cur_sumsq += frame_to_add**2
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
Expand All @@ -1099,7 +1099,7 @@ def sliding_window_cmn(
else:
variance = cur_sumsq
variance = variance / window_frames
variance -= (cur_sum ** 2) / (window_frames ** 2)
variance -= (cur_sum**2) / (window_frames**2)
variance = torch.pow(variance, -0.5)
cmn_specgram[:, t, :] *= variance

Expand Down Expand Up @@ -1725,7 +1725,7 @@ def compute_power_spectral_density_matrix(
def compute_mvdr_weights_souden(
psd_s: Tensor,
psd_n: Tensor,
reference_vector: Tensor,
reference_channel: Union[int, Tensor],
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
Expand All @@ -1743,33 +1743,44 @@ def compute_mvdr_weights_souden(
:math:`\bf{u}` is an one-hot vector that represents the reference channel.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
psd_s (Tensor): The covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
Returns:
torch.Tensor: the mvdr beamforming weight matrix of dimension (..., freq, channel).
Tensor: the mvdr beamforming weight matrix of dimension (..., freq, channel).
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
if isinstance(reference_channel, int):
ref_vector = torch.zeros(
psd_n.size()[:-3] + psd_n.size()[-1:], device=psd_n.device, dtype=psd_n.dtype
) # (..., channel)
ref_vector[..., reference_channel].fill_(1)
else:
ref_vector = reference_channel
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_weights = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
beamform_weights = torch.einsum("...fec,...c->...fe", [ws, ref_vector])

return beamform_weights


def compute_mvdr_weights_rtf(
rtf: Tensor,
psd_n: Tensor,
reference_channel: int = 0,
reference_channel: Union[int, Tensor],
normalize: bool = True,
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
Expand All @@ -1787,10 +1798,14 @@ def compute_mvdr_weights_rtf(
:math:`.^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
Args:
rtf (Tensor): RTF vector of target speech
psd_n (torch.Tensor): PSD matrix of noise
reference_channel (int):
normalize (bool):
rtf (Tensor): RTF vector of target speech.
Tensor of dimension `(..., freq, channel)`.
psd_n (torch.Tensor): The covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
normalize (bool, optional): whether to normalize the RTF vector. (Default: ``True``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
Expand All @@ -1802,15 +1817,22 @@ def compute_mvdr_weights_rtf(
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
if isinstance(reference_channel, int):
ref_vector = torch.zeros(
psd_n.size()[:-3] + psd_n.size()[-1:], device=psd_n.device, dtype=psd_n.dtype
) # (..., channel)
ref_vector[..., reference_channel].fill_(1)
else:
ref_vector = reference_channel
# numerator = psd_n.inv() @ stv
numerator = torch.linalg.solve(psd_n, rtf).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
denominator = torch.einsum("...d,...d->...", [rtf.conj().squeeze(-1), numerator])
beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
# normalzie the numerator
if normalize:
scale = rtf.squeeze(-1)[..., reference_channel, None].conj()
beamform_weights = beamform_weights * scale
scale = torch.einsum("...c,...c->...", [rtf.conj().squeeze(-1), ref_vector[..., None, :]])
beamform_weights = beamform_weights * scale[..., None]

return beamform_weights

Expand All @@ -1819,11 +1841,11 @@ def compute_rtf_evd(psd_s: Tensor) -> Tensor:
r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.
Args:
psd_s (Tensor): covariance matrix of target speech
psd_s (Tensor): The covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
Returns:
Tensor: the estimated RTF of target speech
Tensor: The estimated RTF of target speech.
Tensor of dimension `(..., freq, channel, 1)`
"""
w, v = torch.linalg.eig(psd_s) # (..., freq, channel, channel)
Expand All @@ -1833,30 +1855,39 @@ def compute_rtf_evd(psd_s: Tensor) -> Tensor:
return rtf


def compute_rtf_power(psd_s: Tensor, psd_n: Tensor, reference_vector: Tensor) -> Tensor:
def compute_rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor]) -> Tensor:
r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.
Args:
psd_s (Tensor): covariance matrix of speech
psd_s (Tensor): The covariance matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): covariance matrix of noise
psd_n (Tensor): The covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_vector (Tensor): one-hot reference channel vector
reference_channel (int or Tensor): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
Returns:
Tensor: the estimated RTF of target speech
Tensor of dimension `(..., freq, channel, 1)`
"""
if isinstance(reference_channel, int):
ref_vector = torch.zeros(
psd_n.size()[:-3] + psd_n.size()[-1:], device=psd_n.device, dtype=psd_n.dtype
) # (..., channel)
ref_vector[..., reference_channel].fill_(1)
else:
ref_vector = reference_channel
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
rtf = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
rtf = rtf.unsqueeze(-1)
rtf = torch.einsum("...fec,...c->...fe", [phi, ref_vector])
rtf = rtf.unsqueeze(-1) # (..., freq, channel, 1)
rtf = torch.matmul(phi, rtf)
rtf = torch.matmul(psd_s, rtf)
return rtf


def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
r"""Apply the beamforming weight to the noisy multi-channel spectrograms to get the enhanced spectrogram.
r"""Apply the beamforming weight to the noisy multi-channel spectrograms to get the single-channel enhanced spectrogram.
.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
Expand All @@ -1865,13 +1896,13 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
:math:`\textbf{Y}` is the multi-channel spectrogram for the :math:`f`-th frequency bin.
Args:
beamform_weights (Tensor): beamforming weight matrix
beamform_weights (Tensor): The beamforming weight matrix.
Tensor of dimension `(..., freq, channel)`
specgram (Tensor): multi-channel noisy STFT
specgram (Tensor): The multi-channel noisy spectrogram.
Tensor of dimension `(..., channel, freq, time)`
Returns:
Tensor: the enhanced STFT
Tensor: The single-channel enhanced spectrogram.
Tensor of dimension `(..., freq, time)`
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
Expand Down
18 changes: 10 additions & 8 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
import warnings
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -1656,7 +1656,10 @@ class MVDR(torch.nn.Module):
PSD matrices of speech and noise, respectively.
Args:
ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
ref_channel (int or Tensor, optional): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`.
(Default: ``0``)
solution (str, optional): the solution to get MVDR weight.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``)
Expand All @@ -1680,7 +1683,7 @@ class MVDR(torch.nn.Module):

def __init__(
self,
ref_channel: int = 0,
ref_channel: Union[Tensor, int] = 0,
solution: str = "ref_channel",
multi_mask: bool = False,
diag_loading: bool = True,
Expand Down Expand Up @@ -1824,19 +1827,18 @@ def forward(self, specgram: Tensor, mask_s: Tensor, mask_n: Optional[Tensor] = N
psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel)
psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel)

u = torch.zeros(specgram.size()[:-2], device=specgram.device, dtype=torch.cdouble) # (..., channel)
u[..., self.ref_channel].fill_(1)

if self.online:
psd_s, psd_n = self._get_updated_psds(psd_s, psd_n, mask_s, mask_n)

if self.solution == "ref_channel":
w_mvdr = F.compute_mvdr_weights_souden(psd_s, psd_n, u, self.diag_loading, self.diag_eps, self.eps)
w_mvdr = F.compute_mvdr_weights_souden(
psd_s, psd_n, self.ref_channel, self.diag_loading, self.diag_eps, self.eps
)
else:
if self.solution == "stv_evd":
rtf = F.compute_rtf_evd(psd_s)
else:
rtf = F.compute_rtf_power(psd_s, psd_n, u)
rtf = F.compute_rtf_power(psd_s, psd_n, self.ref_channel)
w_mvdr = F.compute_mvdr_weights_rtf(
rtf, psd_n, self.ref_channel, diagonal_loading=self.diag_loading, diag_eps=self.diag_eps, eps=self.eps
)
Expand Down

0 comments on commit c5c04d1

Please sign in to comment.