Merge pull request espnet#4110 from earthmanylf/dpclanddan
Merge Deep Clustering and Deep Attractor Network to enh separator
sw005320 authored Apr 28, 2022
2 parents b7f0a5a + 406656c commit 72c1d8f
Showing 31 changed files with 1,493 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -133,7 +133,7 @@ To train the neural vocoder, please check the following repositories:
 - Multi-speaker speech separation
 - Unified encoder-separator-decoder structure for time-domain and frequency-domain models
   - Encoder/Decoder: STFT/iSTFT, Convolution/Transposed-Convolution
-  - Separators: BLSTM, Transformer, Conformer, [TasNet](https://arxiv.org/abs/1809.07454), [DPRNN](https://arxiv.org/abs/1910.06379), [DC-CRN](https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf), [DCCRN](https://arxiv.org/abs/2008.00264), Neural Beamformers, etc.
+  - Separators: BLSTM, Transformer, Conformer, [TasNet](https://arxiv.org/abs/1809.07454), [DPRNN](https://arxiv.org/abs/1910.06379), [SkiM](https://arxiv.org/abs/2201.10800), [SVoice](https://arxiv.org/abs/2011.02329), [DC-CRN](https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf), [DCCRN](https://arxiv.org/abs/2008.00264), [Deep Clustering](https://ieeexplore.ieee.org/document/7471631), [Deep Attractor Network](https://pubmed.ncbi.nlm.nih.gov/29430212/), [FaSNet](https://arxiv.org/abs/1909.13387), [iFaSNet](https://arxiv.org/abs/1910.14104), Neural Beamformers, etc.
 - Flexible ASR integration: working as an individual task or as the ASR frontend
 - Easy to import pretrained models from [Asteroid](https://github.com/asteroid-team/asteroid)
   - Both the pre-trained models from Asteroid and the specific configuration are supported.
59 changes: 59 additions & 0 deletions egs2/wsj0_2mix/enh1/README.md
@@ -75,6 +75,65 @@
<!-- Generated by ./scripts/utils/show_enh_score.sh -->
# RESULTS
## Environments
- date: `Thu Feb 24 16:26:21 CST 2022`
- python version: `3.8.10 (default, May 19 2021, 18:05:58) [GCC 7.3.0]`
- espnet version: `espnet 0.10.7a1`
- pytorch version: `pytorch 1.5.1+cu101`
- Git hash: `c58adabbe1b83dcd0b616ecd336b4a0648334e2c`
- Commit date: `Wed Feb 16 14:20:38 2022 +0800`


## enh_train_enh_dpcl_raw

- config: conf/tuning/train_enh_dpcl.yaml
- Pretrained model: https://huggingface.co/Yulinfeng/wsj0_2mix_enh_train_enh_dpcl_raw_valid.si_snr.ave

|dataset|PESQ|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|---|
|enhanced_cv_min_8k|2.18|0.84|9.63|8.59|17.31|8.04|
|enhanced_tt_min_8k|2.15|0.84|9.51|8.45|17.22|7.91|

<!-- Generated by ./scripts/utils/show_enh_score.sh -->
# RESULTS
## Environments
- date: `Thu Mar 3 17:10:03 CST 2022`
- python version: `3.8.10 (default, May 19 2021, 18:05:58) [GCC 7.3.0]`
- espnet version: `espnet 0.10.7a1`
- pytorch version: `pytorch 1.5.1+cu101`
- Git hash: `ec1acec03d109f06d829b80862e0388f7234d0d1`
- Commit date: `Fri Feb 25 14:12:45 2022 +0800`


## enh_train_enh_mdc_raw

- config: conf/tuning/train_enh_mdc.yaml
- Pretrained model: https://huggingface.co/Yulinfeng/wsj0_2mix_enh_train_enh_mdc_raw_valid.si_snr.ave

|dataset|PESQ|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|---|
|enhanced_cv_min_8k|2.20|0.84|9.62|8.57|17.27|8.03|
|enhanced_tt_min_8k|2.18|0.85|9.56|8.50|17.28|7.97|

<!-- Generated by ./scripts/utils/show_enh_score.sh -->
# RESULTS
## Environments
- date: `Thu Mar 3 14:33:32 CST 2022`
- python version: `3.8.10 (default, May 19 2021, 18:05:58) [GCC 7.3.0]`
- espnet version: `espnet 0.10.7a1`
- pytorch version: `pytorch 1.5.1+cu101`
- Git hash: `ec1acec03d109f06d829b80862e0388f7234d0d1`
- Commit date: `Fri Feb 25 14:12:45 2022 +0800`


## enh_train_enh_dan_tf_raw

- config: conf/tuning/train_enh_dan_tf.yaml
- Pretrained model: https://huggingface.co/Yulinfeng/wsj0_2mix_enh_train_enh_dan_tf_raw_valid.si_snr.ave

|dataset|PESQ|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|---|
|enhanced_cv_min_8k|2.68|0.88|12.28|11.01|18.03|10.48|
|enhanced_tt_min_8k|2.68|0.89|12.10|10.84|17.98|10.30|
- date: `Thu Mar 3 14:29:20 CST 2022`
- python version: `3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]`
- espnet version: `espnet 0.10.7a1`
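
The tables above report SI_SNR in dB. As a reading aid, here is a minimal sketch of the usual scale-invariant SNR definition for two mono signals. It uses only the standard library, and the function name `si_snr` and the `eps` value are assumptions for illustration; the scoring code behind `show_enh_score.sh` may differ in detail.

```python
import math
from typing import Sequence


def si_snr(estimate: Sequence[float], reference: Sequence[float], eps: float = 1e-8) -> float:
    """Scale-invariant SNR in dB between two equal-length mono signals."""
    if len(estimate) != len(reference):
        raise ValueError("signals must have the same length")
    # Remove the mean so a constant offset does not affect the score.
    est_mean = sum(estimate) / len(estimate)
    ref_mean = sum(reference) / len(reference)
    est = [e - est_mean for e in estimate]
    ref = [r - ref_mean for r in reference]
    # Project the estimate onto the reference to get the target component.
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref) + eps
    target = [dot / ref_energy * r for r in ref]
    # Whatever is left after the projection counts as noise.
    noise = [e - t for e, t in zip(est, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(n * n for n in noise) + eps
    return 10.0 * math.log10(target_energy / noise_energy + eps)


reference = [1.0, -1.0, 1.0, -1.0]
estimate = [1.1, -0.9, 1.0, -1.0]
score = si_snr(estimate, reference)
# Rescaling the estimate leaves the score essentially unchanged.
rescaled = si_snr([3.0 * e for e in estimate], reference)
```

Because the estimate is projected onto the reference before the ratio is taken, a louder or quieter output is not rewarded or penalized for its gain alone.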
65 changes: 65 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml
@@ -0,0 +1,65 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
iterator_type: chunk
chunk_length: 32000
num_workers: 4
optim_conf:
    lr: 1.0e-04
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
-   - valid
    - si_snr
    - max
-   - valid
    - loss
    - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: mse
    conf:
      compute_on_mask: False
      mask_type: PSM
    # the wrapper for the current criterion
    # PIT is widely used in the speech separation task
    wrapper: pit
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 64
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 64
separator: dan
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: tanh
    layer: 4
    unit: 600
    dropout: 0.1
    emb_D: 20
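
To make the `chunk_length`, `n_fft` and `hop_length` values above concrete, the following sketch works out how many time frames and frequency bins one training chunk yields. It assumes a one-sided STFT with centered framing (the `center=True` behaviour of `torch.stft`); whether the ESPnet encoder is configured that way is an assumption here, and `stft_shape` is a hypothetical helper, not part of ESPnet.

```python
from typing import Tuple


def stft_shape(num_samples: int, n_fft: int, hop_length: int, center: bool = True) -> Tuple[int, int]:
    """Return (num_frames, num_freq_bins) for a one-sided STFT."""
    if center:
        # Centered framing pads n_fft // 2 samples on both sides.
        num_frames = 1 + num_samples // hop_length
    else:
        num_frames = 1 + (num_samples - n_fft) // hop_length
    num_freq_bins = n_fft // 2 + 1
    return num_frames, num_freq_bins


# train_enh_dan_tf.yaml: chunk_length=32000, n_fft=256, hop_length=64
print(stft_shape(32000, 256, 64))   # (501, 129)
# Hypothetical comparison: the same chunk with hop_length=128
print(stft_shape(32000, 256, 128))  # (251, 129)
```

If the audio is sampled at 8 kHz, as the `_8k` dataset names suggest, a 32000-sample chunk covers 4 seconds.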


62 changes: 62 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml
@@ -0,0 +1,62 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
num_workers: 4
optim_conf:
    lr: 1.0e-03
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
-   - valid
    - si_snr
    - max
-   - valid
    - loss
    - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: dpcl
    conf:
      loss_type: dpcl    # "dpcl" or "mdc": "dpcl" is the original loss in Deep Clustering, "mdc" is Manifold-Aware Deep Clustering
    # the wrapper for the current criterion
    # PIT is widely used in the speech separation task
    wrapper: dpcl
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 128
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 128
separator: dpcl
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: relu
    layer: 2
    unit: 500
    dropout: 0.1
    emb_D: 40
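
For readers unfamiliar with the `loss_type: dpcl` option, here is a minimal sketch of the original Deep Clustering objective, the squared Frobenius distance between the embedding affinity matrix and the label affinity matrix. It is written in plain Python for clarity; the names `gram`, `frob_sq` and `dpcl_loss` are hypothetical, and the criterion added in this commit may normalize or weight the terms differently. The "mdc" variant is not covered.

```python
from typing import List

Matrix = List[List[float]]


def gram(a: Matrix, b: Matrix) -> Matrix:
    """Return a^T b for two matrices with the same number of rows."""
    return [
        [sum(row_a[i] * row_b[j] for row_a, row_b in zip(a, b)) for j in range(len(b[0]))]
        for i in range(len(a[0]))
    ]


def frob_sq(m: Matrix) -> float:
    """Squared Frobenius norm."""
    return sum(x * x for row in m for x in row)


def dpcl_loss(embeddings: Matrix, labels: Matrix) -> float:
    """|V V^T - Y Y^T|_F^2, expanded so no N x N matrix is ever formed.

    embeddings: N x D matrix V, one row per time-frequency bin.
    labels:     N x C one-hot matrix Y, marking the dominant speaker per bin.
    """
    return (
        frob_sq(gram(embeddings, embeddings))
        - 2.0 * frob_sq(gram(embeddings, labels))
        + frob_sq(gram(labels, labels))
    )


labels = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
perfect = dpcl_loss(labels, labels)                                     # embeddings match the labels
swapped = dpcl_loss([[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], labels)       # embedding axes permuted
collapsed = dpcl_loss([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]], labels)     # all bins in one cluster
```

The loss only depends on which bins end up close together, so swapping the embedding axes costs nothing, which is why this criterion does not need a permutation search.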


66 changes: 66 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl_e2e.yaml
@@ -0,0 +1,66 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
num_workers: 4
optim_conf:
    lr: 1.0e-03
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
-   - valid
    - si_snr
    - max
-   - valid
    - loss
    - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: mse
    conf:
      compute_on_mask: False
      mask_type: PSM
    # the wrapper for the current criterion
    # PIT is widely used in the speech separation task
    wrapper: pit
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 128
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 128
separator: dpcl_e2e
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: relu
    layer: 2
    unit: 500
    dropout: 0.1
    emb_D: 40
    alpha: 5.0
    max_iteration: 100
    threshold: 1.0e-05
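
The `wrapper: pit` setting above refers to permutation invariant training. A minimal sketch of the idea, paired with an MSE criterion, is given below: the loss is evaluated for every assignment of estimates to references and the cheapest assignment is kept. The helper names `mse` and `pit_mse` are hypothetical and operate on plain lists; the actual ESPnet wrapper works on batched tensors.

```python
from itertools import permutations
from typing import List, Sequence, Tuple


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def pit_mse(
    refs: List[Sequence[float]], ests: List[Sequence[float]]
) -> Tuple[float, Tuple[int, ...]]:
    """Return the lowest mean MSE over all speaker assignments.

    The second return value is the winning permutation, where perm[i]
    is the index of the estimate paired with reference i.
    """
    best_loss, best_perm = float("inf"), ()
    for perm in permutations(range(len(ests))):
        loss = sum(mse(refs[i], ests[p]) for i, p in enumerate(perm)) / len(refs)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


refs = [[1.0, 1.0], [0.0, 0.0]]
ests = [[0.0, 0.1], [1.0, 0.9]]  # the model emitted the speakers in the other order
loss, perm = pit_mse(refs, ests)
```

With `num_spk: 2` there are only two permutations to try, so the exhaustive search is cheap.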


62 changes: 62 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml
@@ -0,0 +1,62 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
num_workers: 4
optim_conf:
    lr: 1.0e-03
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
-   - valid
    - si_snr
    - max
-   - valid
    - loss
    - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: dpcl
    conf:
      loss_type: mdc    # "dpcl" or "mdc": "dpcl" is the original loss in Deep Clustering, "mdc" is Manifold-Aware Deep Clustering
    # the wrapper for the current criterion
    # PIT is widely used in the speech separation task
    wrapper: dpcl
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 128
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 128
separator: dpcl
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: relu
    layer: 2
    unit: 500
    dropout: 0.1
    emb_D: 40
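
The comment in the config describes the overall loss as a weighted sum over the listed criterions, with a default weight of 1.0. A small sketch of that arithmetic follows; `combine_losses` and the example loss names and values are hypothetical and only illustrate the formula from the comment.

```python
from typing import Dict


def combine_losses(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """loss = weight_1 * loss_1 + ... + weight_N * loss_N.

    Any criterion without an explicit weight falls back to 1.0.
    """
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Made-up sub-loss values: "mse" keeps the default weight, "dpcl" is scaled by 0.1.
total = combine_losses({"mse": 0.5, "dpcl": 2.0}, {"dpcl": 0.1})
```

The configs in this commit each list a single criterion with `weight: 1.0`, so their overall loss is just that one criterion's value.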


16 changes: 12 additions & 4 deletions espnet2/enh/espnet_model.py
@@ -15,6 +15,7 @@
 from espnet2.enh.loss.criterions.time_domain import TimeDomainLoss
 from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
 from espnet2.enh.separator.abs_separator import AbsSeparator
+from espnet2.enh.separator.dan_separator import DANSeparator
 from espnet2.torch_utils.device_funcs import force_gatherable
 from espnet2.train.abs_espnet_model import AbsESPnetModel

@@ -134,12 +135,18 @@ def forward(
         # for data-parallel
         speech_ref = speech_ref[..., : speech_lengths.max()]
         speech_ref = speech_ref.unbind(dim=1)
+        additional = {}
+        # Additional data is required in Deep Attractor Network
+        if isinstance(self.separator, DANSeparator):
+            additional["feature_ref"] = [
+                self.encoder(r, speech_lengths)[0] for r in speech_ref
+            ]

         speech_mix = speech_mix[:, : speech_lengths.max()]

         # model forward
         speech_pre, feature_mix, feature_pre, others = self.forward_enhance(
-            speech_mix, speech_lengths
+            speech_mix, speech_lengths, additional
         )

         # loss computation
@@ -159,9 +166,10 @@ def forward_enhance(
         self,
         speech_mix: torch.Tensor,
         speech_lengths: torch.Tensor,
+        additional: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         feature_mix, flens = self.encoder(speech_mix, speech_lengths)
-        feature_pre, flens, others = self.separator(feature_mix, flens)
+        feature_pre, flens, others = self.separator(feature_mix, flens, additional)
         if feature_pre is not None:
             speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
         else:
@@ -192,7 +200,7 @@ def forward_loss(
                     # only select one channel as the reference
                     speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
                 # for the time domain criterions
-                l, s, o = loss_wrapper(speech_ref, speech_pre, o)
+                l, s, o = loss_wrapper(speech_ref, speech_pre, others)
             elif isinstance(criterion, FrequencyDomainLoss):
                 # for the time-frequency domain criterions
                 if criterion.compute_on_mask:
@@ -219,7 +227,7 @@ def forward_loss(
                     tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]
                     tf_pre = feature_pre

-                l, s, o = loss_wrapper(tf_ref, tf_pre, o)
+                l, s, o = loss_wrapper(tf_ref, tf_pre, others)
             else:
                 raise NotImplementedError("Unsupported loss type: %s" % str(criterion))
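
The hunks above thread an optional `additional` dictionary from `forward` through `forward_enhance` into the separator, and only populate it when the separator is a `DANSeparator`. The pattern can be sketched in isolation as follows. The classes `PlainSeparator` and `AttractorSeparator` and the function `forward` are hypothetical stand-ins using plain lists, not the ESPnet classes.

```python
from typing import Dict, List, Optional, Tuple


class PlainSeparator:
    """Hypothetical separator that ignores the extra inputs."""

    def __call__(
        self, feature: List[float], additional: Optional[Dict] = None
    ) -> Tuple[List[float], Dict]:
        return feature, {}


class AttractorSeparator(PlainSeparator):
    """Hypothetical separator that wants reference features during training."""

    def __call__(
        self, feature: List[float], additional: Optional[Dict] = None
    ) -> Tuple[List[float], Dict]:
        # Tolerate both a missing dict and a dict without the key.
        refs = (additional or {}).get("feature_ref")
        others = {"used_refs": refs is not None}
        return feature, others


def forward(separator: PlainSeparator, feature: List[float], feature_ref: List[List[float]]):
    additional: Dict = {}
    # Only build the extra inputs for separators that consume them.
    if isinstance(separator, AttractorSeparator):
        additional["feature_ref"] = feature_ref
    return separator(feature, additional)


_, plain_others = forward(PlainSeparator(), [1.0], [[1.0]])
_, dan_others = forward(AttractorSeparator(), [1.0], [[1.0]])
_, no_ref_others = AttractorSeparator()([1.0])
```

Defaulting `additional` to `None` keeps existing call sites working, since separators that never look at the argument behave exactly as before.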
