diff --git a/espnet2/enh/espnet_model.py b/espnet2/enh/espnet_model.py index 94c16e44f8f..75bb57094f4 100644 --- a/espnet2/enh/espnet_model.py +++ b/espnet2/enh/espnet_model.py @@ -15,6 +15,7 @@ from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper from espnet2.enh.separator.abs_separator import AbsSeparator +from espnet2.enh.separator.dan_separator import DANSeparator from espnet2.torch_utils.device_funcs import force_gatherable from espnet2.train.abs_espnet_model import AbsESPnetModel @@ -135,9 +136,11 @@ def forward( speech_ref = speech_ref[..., : speech_lengths.max()] speech_ref = speech_ref.unbind(dim=1) additional = {} - additional["feature_ref"] = [ - self.encoder(r, speech_lengths)[0] for r in speech_ref - ] + # Additional data is required in Deep Attractor Network + if isinstance(self.separator, DANSeparator): + additional["feature_ref"] = [ + self.encoder(r, speech_lengths)[0] for r in speech_ref + ] speech_mix = speech_mix[:, : speech_lengths.max()] diff --git a/espnet2/enh/loss/criterions/tf_domain.py b/espnet2/enh/loss/criterions/tf_domain.py index c129e7458de..c4a383088cb 100644 --- a/espnet2/enh/loss/criterions/tf_domain.py +++ b/espnet2/enh/loss/criterions/tf_domain.py @@ -346,6 +346,7 @@ def forward(self, ref, inf) -> torch.Tensor: Reference: Independent Vector Analysis with Deep Neural Network Source Priors; Li et al 2020; https://arxiv.org/abs/2008.11273 + Args: ref: (Batch, T, F) or (Batch, T, C, F) inf: (Batch, T, F) or (Batch, T, C, F) diff --git a/espnet2/tasks/enh.py b/espnet2/tasks/enh.py index c5ea374df3d..fd0742359da 100644 --- a/espnet2/tasks/enh.py +++ b/espnet2/tasks/enh.py @@ -74,8 +74,11 @@ classes=dict( asteroid=AsteroidModel_Converter, conformer=ConformerSeparator, + dan=DANSeparator, dc_crn=DC_CRNSeparator, dccrn=DCCRNSeparator, + dpcl=DPCLSeparator, + dpcl_e2e=DPCLE2ESeparator, dprnn=DPRNNSeparator, fasnet=FaSNetSeparator, rnn=RNNSeparator, @@ -83,9 +86,6 @@ svoice=SVoiceSeparator, tcn=TCNSeparator, transformer=TransformerSeparator, - dpcl=DPCLSeparator, - dpcl_e2e=DPCLE2ESeparator, - dan=DANSeparator, wpe_beamformer=NeuralBeamformer, ), type_check=AbsSeparator,