From c1d9be5f4f9eb32bc75fb7a8b2fe406aa997946c Mon Sep 17 00:00:00 2001
From: earthmanylf <411214987@qq.com>
Date: Fri, 25 Feb 2022 15:30:46 +0800
Subject: [PATCH] update for dpcl and dan

---
 .../enh1/conf/tuning/train_enh_dan_tf.yaml    |  65 +++++++
 .../enh1/conf/tuning/train_enh_dpcl.yaml      |  62 +++++++
 .../enh1/conf/tuning/train_enh_mdc.yaml       |  62 +++++++
 espnet2/enh/espnet_model.py                   |   8 +-
 espnet2/enh/loss/criterions/tf_domain.py      |  92 ++++++++++
 espnet2/enh/loss/wrappers/dpcl_solver.py      |  32 ++++
 espnet2/enh/separator/dan_separator.py        | 158 ++++++++++++++++++
 espnet2/enh/separator/dpcl_separator.py       | 134 +++++++++++++++
 espnet2/tasks/enh.py                          |   9 +-
 .../enh/loss/criterions/test_dpcl_loss.py     |  26 +++
 .../enh/loss/wrappers/test_dpcl_solver.py     |  17 ++
 .../enh/separator/test_dan_separator.py       | 130 ++++++++++++++
 .../enh/separator/test_dpcl_separator.py      | 114 +++++++++++++
 13 files changed, 905 insertions(+), 4 deletions(-)
 create mode 100644 egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml
 create mode 100644 egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml
 create mode 100644 egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml
 create mode 100644 espnet2/enh/loss/wrappers/dpcl_solver.py
 create mode 100644 espnet2/enh/separator/dan_separator.py
 create mode 100644 espnet2/enh/separator/dpcl_separator.py
 create mode 100644 test/espnet2/enh/loss/criterions/test_dpcl_loss.py
 create mode 100644 test/espnet2/enh/loss/wrappers/test_dpcl_solver.py
 create mode 100644 test/espnet2/enh/separator/test_dan_separator.py
 create mode 100644 test/espnet2/enh/separator/test_dpcl_separator.py

diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml
new file mode 100644
index 00000000000..d1995a99894
--- /dev/null
+++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml
@@ -0,0 +1,65 @@
+optim: adam
+init: xavier_uniform
+max_epoch: 100
+batch_type: folded
+batch_size: 8
+iterator_type: chunk
+chunk_length: 32000
+num_workers: 4
+optim_conf:
+    lr: 1.0e-04
+    eps: 1.0e-08
+    weight_decay: 1.0e-7
+patience: 10
+val_scheduler_criterion:
+- valid
+- loss
+best_model_criterion:
+-   - valid
+    - si_snr
+    - max
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 1
+scheduler: reducelronplateau
+scheduler_conf:
+    mode: min
+    factor: 0.7
+    patience: 1
+
+# A list of loss criterions
+# The overall loss in multi-task learning will be:
+#     loss = weight_1 * loss_1 + ... + weight_N * loss_N
+# The default `weight` for each sub-loss is 1.0
+criterions:
+  # The first criterion
+  - name: mse
+    conf:
+      compute_on_mask: False
+      mask_type: PSM
+    # the wrapper for the current criterion
+    # PIT is widely used in the speech separation task
+    wrapper: pit
+    wrapper_conf:
+      weight: 1.0
+
+encoder: stft
+encoder_conf:
+    n_fft: 256
+    hop_length: 64
+decoder: stft
+decoder_conf:
+    n_fft: 256
+    hop_length: 64
+separator: dan
+separator_conf:
+    rnn_type: blstm
+    num_spk: 2
+    nonlinear: tanh
+    layer: 4
+    unit: 600
+    dropout: 0.1
+    emb_D: 20
+
+
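
As a concrete illustration of the weighting rule in the comment above (a minimal sketch; the numbers are made up, and the keys mirror `name` and `weight` in the config):

    # Sketch: overall multi-task loss as a weighted sum of sub-losses.
    losses = {"mse": 0.73, "si_snr": -11.2}  # per-criterion loss values
    weights = {"mse": 1.0, "si_snr": 0.5}    # `weight` from each wrapper_conf
    loss = sum(weights[k] * losses[k] for k in losses)
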
diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml
new file mode 100644
index 00000000000..58a06679107
--- /dev/null
+++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml
@@ -0,0 +1,62 @@
+optim: adam
+init: xavier_uniform
+max_epoch: 100
+batch_type: folded
+batch_size: 8
+num_workers: 4
+optim_conf:
+    lr: 1.0e-03
+    eps: 1.0e-08
+    weight_decay: 1.0e-7
+patience: 10
+val_scheduler_criterion:
+- valid
+- loss
+best_model_criterion:
+-   - valid
+    - si_snr
+    - max
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 1
+scheduler: reducelronplateau
+scheduler_conf:
+    mode: min
+    factor: 0.7
+    patience: 1
+
+# A list of loss criterions
+# The overall loss in multi-task learning will be:
+#     loss = weight_1 * loss_1 + ... + weight_N * loss_N
+# The default `weight` for each sub-loss is 1.0
+criterions:
+  # The first criterion
+  - name: dpcl
+    conf:
+      loss_type: dpcl  # "dpcl" (the original Deep Clustering loss) or "mdc" (Manifold-Aware Deep Clustering)
+    # the wrapper for the current criterion;
+    # dpcl computes the loss on the embeddings, so no permutation search (PIT) is needed
+    wrapper: dpcl
+    wrapper_conf:
+      weight: 1.0
+
+encoder: stft
+encoder_conf:
+    n_fft: 256
+    hop_length: 128
+decoder: stft
+decoder_conf:
+    n_fft: 256
+    hop_length: 128
+separator: dpcl
+separator_conf:
+    rnn_type: blstm
+    num_spk: 2
+    nonlinear: relu
+    layer: 2
+    unit: 500
+    dropout: 0.1
+    emb_D: 40
+
+
diff --git a/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml
new file mode 100644
index 00000000000..c093aca6944
--- /dev/null
+++ b/egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml
@@ -0,0 +1,62 @@
+optim: adam
+init: xavier_uniform
+max_epoch: 100
+batch_type: folded
+batch_size: 8
+num_workers: 4
+optim_conf:
+    lr: 1.0e-03
+    eps: 1.0e-08
+    weight_decay: 1.0e-7
+patience: 10
+val_scheduler_criterion:
+- valid
+- loss
+best_model_criterion:
+-   - valid
+    - si_snr
+    - max
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 1
+scheduler: reducelronplateau
+scheduler_conf:
+    mode: min
+    factor: 0.7
+    patience: 1
+
+# A list of loss criterions
+# The overall loss in multi-task learning will be:
+#     loss = weight_1 * loss_1 + ... + weight_N * loss_N
+# The default `weight` for each sub-loss is 1.0
+criterions:
+  # The first criterion
+  - name: dpcl
+    conf:
+      loss_type: mdc  # "dpcl" (the original Deep Clustering loss) or "mdc" (Manifold-Aware Deep Clustering)
+    # the wrapper for the current criterion;
+    # dpcl computes the loss on the embeddings, so no permutation search (PIT) is needed
+    wrapper: dpcl
+    wrapper_conf:
+      weight: 1.0
+
+encoder: stft
+encoder_conf:
+    n_fft: 256
+    hop_length: 128
+decoder: stft
+decoder_conf:
+    n_fft: 256
+    hop_length: 128
+separator: dpcl
+separator_conf:
+    rnn_type: blstm
+    num_spk: 2
+    nonlinear: relu
+    layer: 2
+    unit: 500
+    dropout: 0.1
+    emb_D: 40
+
+
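
The `mdc` variant replaces one-hot targets with the vertices of a regular simplex, so that the target directions of different speakers are maximally separated. A self-contained sketch of that construction (it mirrors `manifold_vector` in tf_domain.py below; `N` stands for the number of speakers):

    import math

    import torch

    N = 3  # number of speakers (assumed > 1)
    scale = math.sqrt(N / (N - 1))
    # off-diagonal entries: -1/N * scale; diagonal entries: (N-1)/N * scale
    v = torch.full((N, N), -1.0 / N) * scale + torch.eye(N) * scale
    # rows are unit vectors with pairwise cosine -1/(N-1), the maximal spread
    print(v @ v.T)
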
diff --git a/espnet2/enh/espnet_model.py b/espnet2/enh/espnet_model.py
index e75e7ec0216..72bf902d560 100644
--- a/espnet2/enh/espnet_model.py
+++ b/espnet2/enh/espnet_model.py
@@ -132,12 +132,14 @@ def forward(
         # for data-parallel
         speech_ref = speech_ref[..., : speech_lengths.max()]
         speech_ref = speech_ref.unbind(dim=1)
+        sep_others = {}
+        sep_others["feature_ref"] = [self.encoder(r, speech_lengths)[0] for r in speech_ref]
 
         speech_mix = speech_mix[:, : speech_lengths.max()]
 
         # model forward
         feature_mix, flens = self.encoder(speech_mix, speech_lengths)
-        feature_pre, flens, others = self.separator(feature_mix, flens)
+        feature_pre, flens, others = self.separator(feature_mix, flens, sep_others)
         if feature_pre is not None:
             speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
         else:
@@ -156,7 +158,7 @@ def forward(
                 # only select one channel as the reference
                 speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
             # for the time domain criterions
-            l, s, o = loss_wrapper(speech_ref, speech_pre, o)
+            l, s, o = loss_wrapper(speech_ref, speech_pre, others)
         elif isinstance(criterion, FrequencyDomainLoss):
             # for the time-frequency domain criterions
             if criterion.compute_on_mask:
@@ -178,7 +180,7 @@ def forward(
                 tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]
                 tf_pre = feature_pre
 
-            l, s, o = loss_wrapper(tf_ref, tf_pre, o)
+            l, s, o = loss_wrapper(tf_ref, tf_pre, others)
 
         loss += l * loss_wrapper.weight
         stats.update(s)
diff --git a/espnet2/enh/loss/criterions/tf_domain.py b/espnet2/enh/loss/criterions/tf_domain.py
index 7be48e010f3..37d452a6e38 100644
--- a/espnet2/enh/loss/criterions/tf_domain.py
+++ b/espnet2/enh/loss/criterions/tf_domain.py
@@ -4,6 +4,8 @@
+import math
 from functools import reduce
 
 import torch
+import torch.nn.functional as F
 
 from espnet2.enh.layers.complex_utils import is_complex
 from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
@@ -185,3 +187,93 @@ def forward(self, ref, inf) -> torch.Tensor:
                 "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
             )
         return l1loss
+
+
+class FrequencyDomainDPCL(FrequencyDomainLoss):
+    def __init__(self, compute_on_mask=False, mask_type="IBM", loss_type="dpcl"):
+        super().__init__()
+        self._compute_on_mask = compute_on_mask
+        self._mask_type = mask_type
+        self._loss_type = loss_type
+
+    @property
+    def compute_on_mask(self) -> bool:
+        return self._compute_on_mask
+
+    @property
+    def mask_type(self) -> str:
+        return self._mask_type
+
+    @property
+    def name(self) -> str:
+        return "dpcl"
+
+    def forward(self, ref, inf) -> torch.Tensor:
+        """Time-frequency Deep Clustering loss.
+
+        References:
+            [1] Deep clustering: Discriminative embeddings for segmentation
+                and separation; John R. Hershey et al., 2016;
+                https://ieeexplore.ieee.org/document/7471631
+            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between
+                Embedding Vectors Based on Regular Simplex;
+                Tanaka, K. et al., 2021;
+                https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html
+
+        Args:
+            ref: List[(Batch, T, F) * spks]
+            inf: (Batch, T*F, D)
+        Returns:
+            loss: (Batch,)
+        """
+        assert len(ref) > 0
+        num_spk = len(ref)
+
+        # Compute the target embedding matrix for Deep Clustering [1][2]
+        if self._loss_type == "dpcl":
+            # one-hot label of the dominant speaker for each T-F bin
+            r = torch.zeros_like(abs(ref[0]))
+            B = ref[0].shape[0]
+            for i in range(0, num_spk):
+                flags = [abs(ref[i]) >= abs(n) for n in ref]
+                mask = reduce(lambda x, y: x * y, flags)
+                mask = mask.int() * i
+                r += mask
+            r = r.contiguous().view(-1).long()
+            re = F.one_hot(r, num_classes=num_spk)
+            re = re.contiguous().view(B, -1, num_spk)
+        elif self._loss_type == "mdc":
+            B = ref[0].shape[0]
+            # rows of manifold_vector are the vertices of a regular simplex
+            manifold_vector = (
+                torch.ones(num_spk, num_spk, device=inf.device)
+                * (-1 / num_spk)
+                * math.sqrt(num_spk / (num_spk - 1))
+            )
+            for i in range(num_spk):
+                manifold_vector[i][i] = ((num_spk - 1) / num_spk) * math.sqrt(
+                    num_spk / (num_spk - 1)
+                )
+
+            re = torch.zeros(
+                ref[0].shape[0], ref[0].shape[1], ref[0].shape[2], num_spk
+            ).to(inf.device)
+            for i in range(0, num_spk):
+                flags = [abs(ref[i]) >= abs(n) for n in ref]
+                mask = reduce(lambda x, y: x * y, flags)
+                mask = mask.int()
+                re[mask == 1] = manifold_vector[i]
+            re = re.contiguous().view(B, -1, num_spk)
+        else:
+            raise ValueError(
+                'Invalid loss type: {}, must be "dpcl" or "mdc"'.format(
+                    self._loss_type
+                )
+            )
+
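+        # The return value below is the Deep Clustering objective
+        #     ||V V^T - Y Y^T||_F^2
+        #         = ||V^T V||_F^2 + ||Y^T Y||_F^2 - 2 ||V^T Y||_F^2
+        # (V = inf, Y = re), expanded so that the (T*F) x (T*F)
+        # affinity matrices are never formed explicitly.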
+        V2 = torch.matmul(torch.transpose(inf, 2, 1), inf).pow(2).sum(dim=(1, 2))
+        Y2 = (
+            torch.matmul(torch.transpose(re, 2, 1).float(), re.float())
+            .pow(2)
+            .sum(dim=(1, 2))
+        )
+        VY = (
+            torch.matmul(torch.transpose(inf, 2, 1), re.float())
+            .pow(2)
+            .sum(dim=(1, 2))
+        )
+
+        return V2 + Y2 - 2 * VY
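
As a quick shape check for the criterion above (a sketch with random tensors; 129 is the number of frequency bins for n_fft=256):

    import torch

    from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL

    criterion = FrequencyDomainDPCL(loss_type="dpcl")
    ref = [torch.rand(2, 10, 129) for _ in range(2)]  # [(B, T, F)] x num_spk
    inf = torch.rand(2, 10 * 129, 40)                 # (B, T*F, D) embeddings
    loss = criterion(ref, inf)
    assert loss.shape == (2,)
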
diff --git a/espnet2/enh/loss/wrappers/dpcl_solver.py b/espnet2/enh/loss/wrappers/dpcl_solver.py
new file mode 100644
index 00000000000..7ae09d7d326
--- /dev/null
+++ b/espnet2/enh/loss/wrappers/dpcl_solver.py
@@ -0,0 +1,32 @@
+from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
+from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
+
+
+class DPCLSolver(AbsLossWrapper):
+    def __init__(self, criterion: AbsEnhLoss, weight=1.0):
+        super().__init__()
+        self.criterion = criterion
+        self.weight = weight
+
+    def forward(self, ref, inf, others={}):
+        """A naive DPCL solver.
+
+        Args:
+            ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
+            inf (List[torch.Tensor]): [(batch, ...), ...]
+                (unused; the loss is computed on the embeddings in `others`)
+            others (dict): must contain "V", the T-F embedding matrix
+                with shape (batch, T * F, D)
+
+        Returns:
+            loss: (torch.Tensor): the DPCL loss averaged over the batch
+            stats: dict, for collecting training status
+            others: reserved
+        """
+        assert "V" in others
+
+        loss = self.criterion(ref, others["V"]).mean()
+
+        stats = dict()
+        stats[self.criterion.name] = loss.detach()
+
+        return loss, stats, {}
diff --git a/espnet2/enh/separator/dan_separator.py b/espnet2/enh/separator/dan_separator.py
new file mode 100644
index 00000000000..8ba947298ef
--- /dev/null
+++ b/espnet2/enh/separator/dan_separator.py
@@ -0,0 +1,158 @@
+from collections import OrderedDict
+from functools import reduce
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as Fun
+from torch_complex.tensor import ComplexTensor
+
+from espnet.nets.pytorch_backend.rnn.encoders import RNN
+from espnet2.enh.separator.abs_separator import AbsSeparator
+
+
+class DANSeparator(AbsSeparator):
+    def __init__(
+        self,
+        input_dim: int,
+        rnn_type: str = "blstm",
+        num_spk: int = 2,
+        nonlinear: str = "tanh",
+        layer: int = 2,
+        unit: int = 512,
+        emb_D: int = 40,
+        dropout: float = 0.0,
+    ):
+        """Deep Attractor Network Separator.
+
+        Reference:
+            Deep attractor network for single-microphone speaker separation;
+            Zhuo Chen et al., 2017;
+            https://pubmed.ncbi.nlm.nih.gov/29430212/
+
+        Args:
+            input_dim: input feature dimension
+            rnn_type: string, select from 'blstm', 'lstm' etc.
+            num_spk: number of speakers
+            nonlinear: the nonlinear function for mask estimation,
+                       select from 'relu', 'tanh', 'sigmoid'
+            layer: int, number of stacked RNN layers. Default is 2.
+            unit: int, dimension of the hidden state.
+            emb_D: int, dimension of the embedding vector for one T-F bin.
+            dropout: float, dropout ratio. Default is 0.
+        """
+        super().__init__()
+
+        self._num_spk = num_spk
+
+        self.blstm = RNN(
+            idim=input_dim,
+            elayers=layer,
+            cdim=unit,
+            hdim=unit,
+            dropout=dropout,
+            typ=rnn_type,
+        )
+
+        self.linear = torch.nn.Linear(unit, input_dim * emb_D)
+
+        if nonlinear not in ("sigmoid", "relu", "tanh"):
+            raise ValueError("Not supporting nonlinear={}".format(nonlinear))
+
+        self.nonlinear = {
+            "sigmoid": torch.nn.Sigmoid(),
+            "relu": torch.nn.ReLU(),
+            "tanh": torch.nn.Tanh(),
+        }[nonlinear]
+
+        self.D = emb_D
+
+    def forward(
+        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor, o=None
+    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
+        """Forward.
+
+        Args:
+            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
+            ilens (torch.Tensor): input lengths [Batch]
+            o (dict): in training, must contain "feature_ref":
+                List[ComplexTensor(B, T, F), ...], the encoded reference sources
+
+        Returns:
+            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
+            ilens (torch.Tensor): (B,)
+            others (OrderedDict): the estimated masks, e.g.
+                'mask_spk1': torch.Tensor(Batch, T, F),
+                ...
+        """
+        # if complex spectrum, use the magnitude as the input feature
+        if isinstance(input, ComplexTensor):
+            feature = abs(input)
+        else:
+            feature = input
+        B, T, F = input.shape
+        # x: (B, T, unit)
+        x, ilens, _ = self.blstm(feature, ilens)
+        # x: (B, T, F*D)
+        x = self.linear(x)
+        x = self.nonlinear(x)
+        # V: (B, T*F, D), one embedding per T-F bin
+        V = x.contiguous().view(B, T * F, -1)
+
+        # Compute the attractors
+        if self.training:
+            assert o is not None and "feature_ref" in o
+            origin = o["feature_ref"]
+            # Y_t: index of the dominant speaker for each T-F bin
+            Y_t = torch.zeros(B, T, F, device=origin[0].device)
+            for i in range(self._num_spk):
+                flags = [abs(origin[i]) >= abs(n) for n in origin]
+                Y = reduce(lambda x, y: x * y, flags)
+                Y = Y.int() * i
+                Y_t += Y
+            Y_t = Y_t.contiguous().view(-1).long()
+            # Y: (B, T*F, spks), one-hot ideal binary mask
+            Y = Fun.one_hot(Y_t, num_classes=self._num_spk)
+            Y = Y.contiguous().view(B, -1, self._num_spk).float()
+
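+            # The attractor of speaker k is the mean embedding over the T-F
+            # bins that the ideal binary mask assigns to speaker k:
+            #     a_k = (V^T y_k) / sum(y_k),
+            # computed below in one batched matmul for all speakers.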
+            # v_y: (B, D, spks)
+            v_y = torch.bmm(torch.transpose(V, 1, 2), Y)
+            # sum_y: (B, D, spks)
+            sum_y = torch.sum(Y, 1, keepdim=True).expand_as(v_y)
+            # attractor: (B, D, spks)
+            attractor = v_y / (sum_y + 1e-8)
+        else:
+            # K-means for batch
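+            # At inference there are no oracle masks, so the embeddings are
+            # clustered by a plain K-means, iterated until the labels stop
+            # changing. The first num_spk embeddings seed the centroids;
+            # an empty cluster would yield a NaN mean (a known limitation
+            # of this simple initialization).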
+            centers = V[:, : self._num_spk, :].detach()
+            dist = torch.empty(B, T * F, self._num_spk).to(V.device)
+            last_label = torch.zeros(B, T * F).to(V.device)
+            while True:
+                for i in range(self._num_spk):
+                    dist[:, :, i] = torch.sum(
+                        (V - centers[:, i, :].unsqueeze(1)) ** 2, dim=2
+                    )
+                label = dist.argmin(dim=2)
+                if torch.sum(label != last_label) == 0:
+                    break
+                last_label = label
+                for b in range(B):
+                    for i in range(self._num_spk):
+                        centers[b, i] = V[b, label[b] == i].mean(dim=0)
+            attractor = centers.permute(0, 2, 1)
+
+        # calculate the distance between embeddings and attractors,
+        # then generate the masks; dist: (B, T*F, spks)
+        dist = torch.bmm(V, attractor)
+        masks = torch.softmax(dist, dim=2)
+        masks = masks.contiguous().view(B, T, F, self._num_spk).unbind(dim=3)
+
+        masked = [input * m for m in masks]
+
+        others = OrderedDict(
+            zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
+        )
+
+        return masked, ilens, others
+
+    @property
+    def num_spk(self):
+        return self._num_spk
diff --git a/espnet2/enh/separator/dpcl_separator.py b/espnet2/enh/separator/dpcl_separator.py
new file mode 100644
index 00000000000..b4e31945276
--- /dev/null
+++ b/espnet2/enh/separator/dpcl_separator.py
@@ -0,0 +1,134 @@
+from collections import OrderedDict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import torch
+from torch_complex.tensor import ComplexTensor
+
+from espnet.nets.pytorch_backend.rnn.encoders import RNN
+from espnet2.enh.separator.abs_separator import AbsSeparator
+
+
+class DPCLSeparator(AbsSeparator):
+    def __init__(
+        self,
+        input_dim: int,
+        rnn_type: str = "blstm",
+        num_spk: int = 2,
+        nonlinear: str = "tanh",
+        layer: int = 2,
+        unit: int = 512,
+        emb_D: int = 40,
+        dropout: float = 0.0,
+    ):
+        """Deep Clustering Separator.
+
+        References:
+            [1] Deep clustering: Discriminative embeddings for segmentation
+                and separation; John R. Hershey et al., 2016;
+                https://ieeexplore.ieee.org/document/7471631
+            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between
+                Embedding Vectors Based on Regular Simplex;
+                Tanaka, K. et al., 2021;
+                https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html
+
+        Args:
+            input_dim: input feature dimension
+            rnn_type: string, select from 'blstm', 'lstm' etc.
+            num_spk: number of speakers
+            nonlinear: the nonlinear function for mask estimation,
+                       select from 'relu', 'tanh', 'sigmoid'
+            layer: int, number of stacked RNN layers. Default is 2.
+            unit: int, dimension of the hidden state.
+            emb_D: int, dimension of the embedding vector for one T-F bin.
+            dropout: float, dropout ratio. Default is 0.
+        """
+        super().__init__()
+
+        self._num_spk = num_spk
+
+        self.blstm = RNN(
+            idim=input_dim,
+            elayers=layer,
+            cdim=unit,
+            hdim=unit,
+            dropout=dropout,
+            typ=rnn_type,
+        )
+
+        self.linear = torch.nn.Linear(unit, input_dim * emb_D)
+
+        if nonlinear not in ("sigmoid", "relu", "tanh"):
+            raise ValueError("Not supporting nonlinear={}".format(nonlinear))
+
+        self.nonlinear = {
+            "sigmoid": torch.nn.Sigmoid(),
+            "relu": torch.nn.ReLU(),
+            "tanh": torch.nn.Tanh(),
+        }[nonlinear]
+
+        self.D = emb_D
+
+    def forward(
+        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor, o=None
+    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
+        """Forward.
+
+        Args:
+            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
+            ilens (torch.Tensor): input lengths [Batch]
+
+        Returns:
+            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
+                (None in training, where only the embeddings are needed)
+            ilens (torch.Tensor): (B,)
+            others (OrderedDict): predicted data, e.g.
+                'V': torch.Tensor(Batch, T * F, D)
+        """
+        # if complex spectrum, use the magnitude as the input feature
+        if isinstance(input, ComplexTensor):
+            feature = abs(input)
+        else:
+            feature = input
+        B, T, F = input.shape
+        # x: (B, T, unit)
+        x, ilens, _ = self.blstm(feature, ilens)
+        # x: (B, T, F*D)
+        x = self.linear(x)
+        x = self.nonlinear(x)
+        # V: (B, T*F, D), one embedding per T-F bin
+        V = x.view(B, -1, self.D)
+
+        if self.training:
+            # the DPCL loss is computed on V directly; no masks are needed
+            masked = None
+        else:
+            # K-means for batch (same procedure as in dan_separator.py)
+            centers = V[:, : self._num_spk, :].detach()
+            dist = torch.empty(B, T * F, self._num_spk).to(V.device)
+            last_label = torch.zeros(B, T * F).to(V.device)
+            while True:
+                for i in range(self._num_spk):
+                    dist[:, :, i] = torch.sum(
+                        (V - centers[:, i, :].unsqueeze(1)) ** 2, dim=2
+                    )
+                label = dist.argmin(dim=2)
+                if torch.sum(label != last_label) == 0:
+                    break
+                last_label = label
+                for b in range(B):
+                    for i in range(self._num_spk):
+                        centers[b, i] = V[b, label[b] == i].mean(dim=0)
+            label = label.view(B, T, F)
+            # binary masks from the cluster labels
+            masked = []
+            for i in range(self._num_spk):
+                masked.append(input * (label == i))
+
+        others = OrderedDict({"V": V})
+
+        return masked, ilens, others
+
+    @property
+    def num_spk(self):
+        return self._num_spk
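
For reference, a minimal sketch of how the DPCL pieces above fit together during training (random tensors, illustrative sizes; `masked` is None in training mode, and the loss is computed on the embeddings in `others`):

    import torch

    from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
    from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver
    from espnet2.enh.separator.dpcl_separator import DPCLSeparator

    separator = DPCLSeparator(input_dim=129, num_spk=2, emb_D=40)
    separator.train()
    feat = torch.rand(2, 10, 129)                     # |STFT| of the mixture
    masked, _, others = separator(feat, torch.tensor([10, 10]))
    ref = [torch.rand(2, 10, 129) for _ in range(2)]  # reference magnitudes
    loss, stats, _ = DPCLSolver(FrequencyDomainDPCL())(ref, masked, others)
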
diff --git a/espnet2/tasks/enh.py b/espnet2/tasks/enh.py
index 2a722cba554..9b862d708c4 100644
--- a/espnet2/tasks/enh.py
+++ b/espnet2/tasks/enh.py
@@ -23,12 +23,14 @@
 from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
+from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
 from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
 from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE
 from espnet2.enh.loss.criterions.time_domain import CISDRLoss
 from espnet2.enh.loss.criterions.time_domain import SISNRLoss
 from espnet2.enh.loss.criterions.time_domain import SNRLoss
 from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
+from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver
 from espnet2.enh.loss.wrappers.fixed_order import FixedOrderSolver
 from espnet2.enh.loss.wrappers.pit_solver import PITSolver
 from espnet2.enh.separator.abs_separator import AbsSeparator
 from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter
 from espnet2.enh.separator.conformer_separator import ConformerSeparator
@@ -37,6 +39,8 @@
+from espnet2.enh.separator.dan_separator import DANSeparator
+from espnet2.enh.separator.dpcl_separator import DPCLSeparator
 from espnet2.enh.separator.rnn_separator import RNNSeparator
 from espnet2.enh.separator.tcn_separator import TCNSeparator
 from espnet2.enh.separator.transformer_separator import TransformerSeparator
 from espnet2.tasks.abs_task import AbsTask
 from espnet2.torch_utils.initialize import initialize
 from espnet2.train.class_choices import ClassChoices
@@ -62,6 +66,8 @@
         dprnn=DPRNNSeparator,
         transformer=TransformerSeparator,
         conformer=ConformerSeparator,
+        dpcl=DPCLSeparator,
+        dan=DANSeparator,
         wpe_beamformer=NeuralBeamformer,
         asteroid=AsteroidModel_Converter,
     ),
@@ -78,7 +84,7 @@
 loss_wrapper_choices = ClassChoices(
     name="loss_wrappers",
-    classes=dict(pit=PITSolver, fixed_order=FixedOrderSolver),
+    classes=dict(pit=PITSolver, fixed_order=FixedOrderSolver, dpcl=DPCLSolver),
     type_check=AbsLossWrapper,
     default=None,
 )
@@ -91,6 +97,7 @@
         si_snr=SISNRLoss,
         mse=FrequencyDomainMSE,
         l1=FrequencyDomainL1,
+        dpcl=FrequencyDomainDPCL,
     ),
     type_check=AbsEnhLoss,
     default=None,
diff --git a/test/espnet2/enh/loss/criterions/test_dpcl_loss.py b/test/espnet2/enh/loss/criterions/test_dpcl_loss.py
new file mode 100644
index 00000000000..f949fbb6874
--- /dev/null
+++ b/test/espnet2/enh/loss/criterions/test_dpcl_loss.py
@@ -0,0 +1,26 @@
+import pytest
+import torch
+from torch_complex import ComplexTensor
+
+from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
+
+
+@pytest.mark.parametrize("criterion_class", [FrequencyDomainDPCL])
+@pytest.mark.parametrize("mask_type", ["IBM"])
+@pytest.mark.parametrize("compute_on_mask", [False])
+@pytest.mark.parametrize("loss_type", ["dpcl", "mdc"])
+def test_tf_domain_criterion_forward(
+    criterion_class, mask_type, compute_on_mask, loss_type
+):
+    criterion = criterion_class(
+        compute_on_mask=compute_on_mask, mask_type=mask_type, loss_type=loss_type
+    )
+
+    batch = 2
+    inf = torch.rand(batch, 10 * 200, 40)
+    ref_spec = [
+        ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200))
+        for _ in range(3)
+    ]
+    ref = [abs(r) for r in ref_spec]
+
+    loss = criterion(ref, inf)
+    assert loss.shape == (batch,)
diff --git a/test/espnet2/enh/loss/wrappers/test_dpcl_solver.py b/test/espnet2/enh/loss/wrappers/test_dpcl_solver.py
new file mode 100644
index 00000000000..65bace71d2c
--- /dev/null
+++ b/test/espnet2/enh/loss/wrappers/test_dpcl_solver.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
+from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver
+
+
+@pytest.mark.parametrize("num_spk", [1, 2, 3])
+def test_DPCLSolver_forward(num_spk):
+    batch = 2
+    o = {"V": torch.rand(batch, 10 * 200, 40)}
+    inf = [torch.rand(batch, 10, 200) for spk in range(num_spk)]
+    ref = [inf[num_spk - spk - 1] for spk in range(num_spk)]  # reversed inf as ref
+    solver = DPCLSolver(FrequencyDomainDPCL())
+
+    loss, stats, others = solver(ref, inf, o)
\ No newline at end of file
diff --git a/test/espnet2/enh/separator/test_dan_separator.py b/test/espnet2/enh/separator/test_dan_separator.py
new file mode 100644
index 00000000000..473f2316201
--- /dev/null
+++ b/test/espnet2/enh/separator/test_dan_separator.py
@@ -0,0 +1,130 @@
+import pytest
+
+import torch
+from torch import Tensor
+from torch_complex import ComplexTensor
+
+from espnet2.enh.separator.dan_separator import DANSeparator
+
+
+@pytest.mark.parametrize("input_dim", [5])
+@pytest.mark.parametrize("rnn_type", ["blstm"])
+@pytest.mark.parametrize("layer", [1, 3])
+@pytest.mark.parametrize("unit", [8])
+@pytest.mark.parametrize("dropout", [0.0, 0.2])
+@pytest.mark.parametrize("num_spk", [2])
+@pytest.mark.parametrize("emb_D", [40])
+@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
+def test_dan_separator_forward_backward_complex(
+    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
+):
+    model = DANSeparator(
+        input_dim=input_dim,
+        rnn_type=rnn_type,
+        layer=layer,
+        unit=unit,
+        dropout=dropout,
+        num_spk=num_spk,
+        emb_D=emb_D,
+        nonlinear=nonlinear,
+    )
+    model.train()
+
+    real = torch.rand(2, 10, input_dim)
+    imag = torch.rand(2, 10, input_dim)
+    x = ComplexTensor(real, imag)
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    o = []
+    for i in range(num_spk):
+        o.append(ComplexTensor(real, imag))
+
+    sep_others = {}
+    sep_others["feature_ref"] = o
+
+    masked, flens, others = model(x, ilens=x_lens, o=sep_others)
+
+    assert isinstance(masked[0], ComplexTensor)
+    assert len(masked) == num_spk
+
+    masked[0].abs().mean().backward()
+
+
+@pytest.mark.parametrize("input_dim", [5])
+@pytest.mark.parametrize("rnn_type", ["blstm"])
+@pytest.mark.parametrize("layer", [1, 3])
+@pytest.mark.parametrize("unit", [8])
+@pytest.mark.parametrize("dropout", [0.0, 0.2])
+@pytest.mark.parametrize("num_spk", [1, 2])
+@pytest.mark.parametrize("emb_D", [40])
+@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
+def test_dan_separator_forward_backward_real(
+    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
+):
+    model = DANSeparator(
+        input_dim=input_dim,
+        rnn_type=rnn_type,
+        layer=layer,
+        unit=unit,
+        dropout=dropout,
+        num_spk=num_spk,
+        emb_D=emb_D,
+        nonlinear=nonlinear,
+    )
+    model.train()
+
+    x = torch.rand(2, 10, input_dim)
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    o = []
+    for i in range(num_spk):
+        o.append(ComplexTensor(x, x))
+
+    sep_others = {}
+    sep_others["feature_ref"] = o
+
+    masked, flens, others = model(x, ilens=x_lens, o=sep_others)
+
+    assert isinstance(masked[0], Tensor)
+    assert len(masked) == num_spk
+
+    masked[0].abs().mean().backward()
+
+
+def test_dan_separator_invalid_type():
+    with pytest.raises(ValueError):
+        DANSeparator(
+            input_dim=10,
+            rnn_type="rnn",
+            layer=2,
+            unit=10,
+            dropout=0.1,
+            num_spk=2,
+            emb_D=40,
+            nonlinear="fff",
+        )
+
+
+def test_dan_separator_output():
+
+    x = torch.rand(1, 10, 10)
+    x_lens = torch.tensor([10], dtype=torch.long)
+
+    for num_spk in range(1, 4):
+        model = DANSeparator(
+            input_dim=10,
+            rnn_type="rnn",
+            layer=2,
+            unit=10,
+            dropout=0.1,
+            num_spk=num_spk,
+            emb_D=40,
+            nonlinear="relu",
+        )
+        model.eval()
+        specs, _, others = model(x, x_lens)
+        assert isinstance(specs, list)
+        assert isinstance(others, dict)
+        for n in range(num_spk):
+            assert "mask_spk{}".format(n + 1) in others
+            assert specs[n].shape == others["mask_spk{}".format(n + 1)].shape
diff --git a/test/espnet2/enh/separator/test_dpcl_separator.py b/test/espnet2/enh/separator/test_dpcl_separator.py
new file mode 100644
index 00000000000..1905b6ef092
--- /dev/null
+++ b/test/espnet2/enh/separator/test_dpcl_separator.py
@@ -0,0 +1,114 @@
+import pytest
+
+import torch
+from torch_complex import ComplexTensor
+
+from espnet2.enh.separator.dpcl_separator import DPCLSeparator
+
+
+@pytest.mark.parametrize("input_dim", [5])
+@pytest.mark.parametrize("rnn_type", ["blstm"])
+@pytest.mark.parametrize("layer", [1, 3])
+@pytest.mark.parametrize("unit", [8])
+@pytest.mark.parametrize("dropout", [0.0, 0.2])
+@pytest.mark.parametrize("num_spk", [2])
+@pytest.mark.parametrize("emb_D", [40])
+@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
+def test_dpcl_separator_forward_backward_complex(
+    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
+):
+    model = DPCLSeparator(
+        input_dim=input_dim,
+        rnn_type=rnn_type,
+        layer=layer,
+        unit=unit,
+        dropout=dropout,
+        num_spk=num_spk,
+        emb_D=emb_D,
+        nonlinear=nonlinear,
+    )
+    model.train()
+
+    real = torch.rand(2, 10, input_dim)
+    imag = torch.rand(2, 10, input_dim)
+    x = ComplexTensor(real, imag)
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    masked, flens, others = model(x, ilens=x_lens)
+
+    assert "V" in others
+
+    others["V"].abs().mean().backward()
+
+
+@pytest.mark.parametrize("input_dim", [5])
+@pytest.mark.parametrize("rnn_type", ["blstm"])
+@pytest.mark.parametrize("layer", [1, 3])
+@pytest.mark.parametrize("unit", [8])
+@pytest.mark.parametrize("dropout", [0.0, 0.2])
+@pytest.mark.parametrize("num_spk", [1, 2])
+@pytest.mark.parametrize("emb_D", [40])
+@pytest.mark.parametrize("nonlinear", ["relu", "sigmoid", "tanh"])
+def test_dpcl_separator_forward_backward_real(
+    input_dim, rnn_type, layer, unit, dropout, num_spk, emb_D, nonlinear
+):
+    model = DPCLSeparator(
+        input_dim=input_dim,
+        rnn_type=rnn_type,
+        layer=layer,
+        unit=unit,
+        dropout=dropout,
+        num_spk=num_spk,
+        emb_D=emb_D,
+        nonlinear=nonlinear,
+    )
+    model.train()
+
+    x = torch.rand(2, 10, input_dim)
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    masked, flens, others = model(x, ilens=x_lens)
+
+    assert "V" in others
+
+    others["V"].abs().mean().backward()
+
+
+def test_dpcl_separator_invalid_type():
+    with pytest.raises(ValueError):
+        DPCLSeparator(
+            input_dim=10,
+            rnn_type="rnn",
+            layer=2,
+            unit=10,
+            dropout=0.1,
+            num_spk=2,
+            emb_D=40,
+            nonlinear="fff",
+        )
+
+
+def test_dpcl_separator_output():
+
+    x = torch.rand(2, 10, 10)
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    for num_spk in range(1, 4):
+        model = DPCLSeparator(
+            input_dim=10,
+            rnn_type="rnn",
+            layer=2,
+            unit=10,
+            dropout=0.1,
+            num_spk=num_spk,
+            emb_D=40,
+            nonlinear="relu",
+        )
+        model.eval()
+        specs, _, others = model(x, x_lens)
+        assert isinstance(specs, list)
+        assert isinstance(others, dict)
+        assert len(specs) == num_spk, len(specs)
+        for n in range(num_spk):
+            assert "V" in others
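
And the DAN counterpart in training mode (again a sketch with random tensors; `feature_ref` carries the encoded reference sources, exactly what espnet_model.py passes to the separator):

    import torch
    from torch_complex.tensor import ComplexTensor

    from espnet2.enh.separator.dan_separator import DANSeparator

    model = DANSeparator(input_dim=129, num_spk=2, emb_D=20)
    model.train()
    x = ComplexTensor(torch.rand(2, 10, 129), torch.rand(2, 10, 129))
    refs = [
        ComplexTensor(torch.rand(2, 10, 129), torch.rand(2, 10, 129))
        for _ in range(2)
    ]
    masked, _, others = model(x, ilens=torch.tensor([10, 8]), o={"feature_ref": refs})
    assert len(masked) == 2  # one masked spectrum per speaker
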