Commit

update for dpcl and dan
earthmanylf committed Feb 25, 2022
1 parent ee20e18 commit c1d9be5
Showing 13 changed files with 905 additions and 4 deletions.
65 changes: 65 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dan_tf.yaml
@@ -0,0 +1,65 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
iterator_type: chunk
chunk_length: 32000
num_workers: 4
optim_conf:
    lr: 1.0e-04
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - si_snr
  - max
- - valid
  - loss
  - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list of criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: mse
    conf:
      compute_on_mask: False
      mask_type: PSM
    # the wrapper for the current criterion
    # PIT is widely used in the speech separation task
    wrapper: pit
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 64
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 64
separator: dan
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: tanh
    layer: 4
    unit: 600
    dropout: 0.1
    emb_D: 20
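The `criterions` list above defines the multi-task objective loss = weight_1 * loss_1 + ... + weight_N * loss_N, and the `pit` wrapper resolves the speaker-permutation ambiguity before the MSE sub-loss is applied. A minimal sketch of the permutation-invariant idea, using a plain per-speaker MSE and assumed (batch, T, F) shapes rather than ESPnet's actual PIT solver:

from itertools import permutations

import torch


def pit_mse(ref, inf):
    # ref, inf: lists of (batch, T, F) tensors, one per speaker (assumed shapes)
    num_spk = len(ref)
    losses = []
    for perm in permutations(range(num_spk)):
        # mean squared error for this particular speaker assignment
        loss_perm = sum(
            torch.mean((ref[s] - inf[p]) ** 2, dim=(1, 2)) for s, p in enumerate(perm)
        ) / num_spk
        losses.append(loss_perm)
    # keep the best permutation independently for each batch element
    return torch.stack(losses, dim=0).min(dim=0).values


# toy usage: 2 speakers, batch of 4, 100 frames, 129 frequency bins
ref = [torch.randn(4, 100, 129) for _ in range(2)]
inf = [ref[1] + 0.01 * torch.randn(4, 100, 129), ref[0]]  # outputs in swapped order
print(pit_mse(ref, inf))  # near zero: PIT recovers the swapped assignment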


62 changes: 62 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_dpcl.yaml
@@ -0,0 +1,62 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
num_workers: 4
optim_conf:
    lr: 1.0e-03
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - si_snr
  - max
- - valid
  - loss
  - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list of criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: dpcl
    conf:
      loss_type: dpcl # "dpcl" or "mdc"; "dpcl" is the original Deep Clustering loss and "mdc" is the Manifold-Aware Deep Clustering loss
    # the wrapper for the current criterion
    # the dpcl solver wraps the deep clustering criterion
    wrapper: dpcl
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 128
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 128
separator: dpcl
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: relu
    layer: 2
    unit: 500
    dropout: 0.1
    emb_D: 40
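With the STFT settings above (n_fft: 256, hop_length: 128) the encoder produces n_fft // 2 + 1 = 129 frequency bins per frame, and the dpcl separator embeds every time-frequency bin into an emb_D = 40 dimensional vector, i.e. a tensor of shape (Batch, T*F, D) that the dpcl loss consumes. A rough shape check under these assumptions, using plain torch.stft instead of ESPnet's STFT encoder (wsj0_2mix is commonly prepared at 8 kHz):

import torch

n_fft, hop_length, emb_D = 256, 128, 40

mix = torch.randn(2, 8000)  # batch of 2, one second at an assumed 8 kHz
spec = torch.stft(
    mix, n_fft=n_fft, hop_length=hop_length,
    window=torch.hann_window(n_fft), return_complex=True,
)
batch, n_freq, n_frames = spec.shape
print(n_freq)  # 129 = n_fft // 2 + 1

# the separator maps each T-F bin to a (usually unit-norm) emb_D-dim embedding
V = torch.randn(batch, n_frames * n_freq, emb_D)
V = V / V.norm(dim=-1, keepdim=True)
print(V.shape)  # (2, T*F, 40) -- this is the "V" that the dpcl wrapper expects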


62 changes: 62 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mdc.yaml
@@ -0,0 +1,62 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
batch_size: 8
num_workers: 4
optim_conf:
    lr: 1.0e-03
    eps: 1.0e-08
    weight_decay: 1.0e-7
patience: 10
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
  - si_snr
  - max
- - valid
  - loss
  - min
keep_nbest_models: 1
scheduler: reducelronplateau
scheduler_conf:
    mode: min
    factor: 0.7
    patience: 1

# A list of criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
  # The first criterion
  - name: dpcl
    conf:
      loss_type: mdc # "dpcl" or "mdc"; "dpcl" is the original Deep Clustering loss and "mdc" is the Manifold-Aware Deep Clustering loss
    # the wrapper for the current criterion
    # the dpcl solver wraps the deep clustering criterion
    wrapper: dpcl
    wrapper_conf:
      weight: 1.0

encoder: stft
encoder_conf:
    n_fft: 256
    hop_length: 128
decoder: stft
decoder_conf:
    n_fft: 256
    hop_length: 128
separator: dpcl
separator_conf:
    rnn_type: blstm
    num_spk: 2
    nonlinear: relu
    layer: 2
    unit: 500
    dropout: 0.1
    emb_D: 40
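The only change relative to train_enh_dpcl.yaml is loss_type: mdc. For that loss, each speaker's time-frequency bins are mapped to one vertex of a regular simplex (the manifold_vector built in the criterion below), giving unit-norm targets with equal, maximal pairwise angles. A small numeric check of that construction, mirroring the formula from the criterion purely for illustration:

import math

import torch

num_spk = 3
# off-diagonal: -1/N * sqrt(N / (N - 1)); diagonal: (N - 1)/N * sqrt(N / (N - 1))
mv = torch.full(
    (num_spk, num_spk), (-1 / num_spk) * math.sqrt(num_spk / (num_spk - 1))
)
for i in range(num_spk):
    mv[i, i] = ((num_spk - 1) / num_spk) * math.sqrt(num_spk / (num_spk - 1))

print(mv.norm(dim=1))  # all ones: every vertex has unit length
print(mv @ mv.T)       # off-diagonal entries are -1/(N-1): equal, maximal angles
print(mv.sum(dim=0))   # zeros: the vertices are centered at the origin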


8 changes: 5 additions & 3 deletions espnet2/enh/espnet_model.py
@@ -132,12 +132,14 @@ def forward(
        # for data-parallel
        speech_ref = speech_ref[..., : speech_lengths.max()]
        speech_ref = speech_ref.unbind(dim=1)
        sep_others = {}
        sep_others["feature_ref"] = [self.encoder(r, speech_lengths)[0] for r in speech_ref]

        speech_mix = speech_mix[:, : speech_lengths.max()]

        # model forward
        feature_mix, flens = self.encoder(speech_mix, speech_lengths)
        feature_pre, flens, others = self.separator(feature_mix, flens)
        feature_pre, flens, others = self.separator(feature_mix, flens, sep_others)
        if feature_pre is not None:
            speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
        else:
@@ -156,7 +158,7 @@ def forward(
                # only select one channel as the reference
                speech_ref = [sr[..., self.ref_channel] for sr in speech_ref]
                # for the time domain criterions
                l, s, o = loss_wrapper(speech_ref, speech_pre, o)
                l, s, o = loss_wrapper(speech_ref, speech_pre, others)
            elif isinstance(criterion, FrequencyDomainLoss):
                # for the time-frequency domain criterions
                if criterion.compute_on_mask:
@@ -178,7 +180,7 @@ def forward(
                    tf_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref]
                    tf_pre = feature_pre

                l, s, o = loss_wrapper(tf_ref, tf_pre, o)
                l, s, o = loss_wrapper(tf_ref, tf_pre, others)
            loss += l * loss_wrapper.weight
            stats.update(s)
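The re-encoded reference sources are passed to the separator through sep_others["feature_ref"]; a Deep Attractor Network needs them during training to decide which source dominates each time-frequency bin and to average the corresponding embeddings into one attractor per speaker. A rough, self-contained sketch of that attractor step with assumed shapes (not ESPnet's actual DANSeparator):

import torch


def oracle_attractors(V, feature_ref):
    # V:           (B, T*F, D) embeddings from the separator network
    # feature_ref: list of (B, T, F) reference magnitude spectrograms
    B, TF, D = V.shape
    # oracle assignment: each T-F bin belongs to its loudest source
    mags = torch.stack([r.abs().reshape(B, TF) for r in feature_ref], dim=-1)
    Y = torch.nn.functional.one_hot(
        mags.argmax(dim=-1), num_classes=len(feature_ref)
    ).float()  # (B, T*F, n_spk)
    # attractor = mean embedding over the bins assigned to each speaker
    return torch.einsum("bts,btd->bsd", Y, V) / (Y.sum(dim=1).unsqueeze(-1) + 1e-8)


# toy shapes: batch 2, 63 frames, 129 bins, 20-dim embeddings, 2 speakers
V = torch.randn(2, 63 * 129, 20)
refs = [torch.randn(2, 63, 129).abs() for _ in range(2)]
print(oracle_attractors(V, refs).shape)  # torch.Size([2, 2, 20])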

92 changes: 92 additions & 0 deletions espnet2/enh/loss/criterions/tf_domain.py
@@ -4,6 +4,8 @@
from functools import reduce

import torch
import torch.nn.functional as F
import math

from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
@@ -185,3 +187,93 @@ def forward(self, ref, inf) -> torch.Tensor:
"Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
)
return l1loss


class FrequencyDomainDPCL(FrequencyDomainLoss):
    def __init__(self, compute_on_mask=False, mask_type="IBM", loss_type="dpcl"):
        super().__init__()
        self._compute_on_mask = compute_on_mask
        self._mask_type = mask_type
        self._loss_type = loss_type

    @property
    def compute_on_mask(self) -> bool:
        return self._compute_on_mask

    @property
    def mask_type(self) -> str:
        return self._mask_type

    @property
    def name(self) -> str:
        return "dpcl"

    def forward(self, ref, inf) -> torch.Tensor:
        """Time-frequency Deep Clustering loss.

        References:
            [1] Deep clustering: Discriminative embeddings for segmentation and separation;
                John R. Hershey et al., 2016;
                https://ieeexplore.ieee.org/document/7471631
            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding Vectors Based on Regular Simplex;
                Tanaka, K. et al., 2021;
                https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html

        Args:
            ref: List[(Batch, T, F) * spks]
            inf: (Batch, T*F, D)
        Returns:
            loss: (Batch,)
        """
        assert len(ref) > 0
        num_spk = len(ref)

        # Compute the ref for Deep Clustering [1][2]
        if self._loss_type == "dpcl":
            r = torch.zeros_like(abs(ref[0]))
            B = ref[0].shape[0]
            for i in range(0, num_spk):
                flags = [abs(ref[i]) >= abs(n) for n in ref]
                mask = reduce(lambda x, y: x * y, flags)
                mask = mask.int() * i
                r += mask
            r = r.contiguous().view(-1).long()
            re = F.one_hot(r, num_classes=num_spk)
            re = re.contiguous().view(B, -1, num_spk)
        elif self._loss_type == "mdc":
            B = ref[0].shape[0]
            manifold_vector = (
                torch.ones(num_spk, num_spk, device=inf.device)
                * (-1 / num_spk)
                * math.sqrt(num_spk / (num_spk - 1))
            )
            for i in range(num_spk):
                manifold_vector[i][i] = ((num_spk - 1) / num_spk) * math.sqrt(
                    num_spk / (num_spk - 1)
                )

            re = torch.zeros(
                ref[0].shape[0], ref[0].shape[1], ref[0].shape[2], num_spk
            ).to(inf.device)
            for i in range(0, num_spk):
                flags = [abs(ref[i]) >= abs(n) for n in ref]
                mask = reduce(lambda x, y: x * y, flags)
                mask = mask.int()
                re[mask == 1] = manifold_vector[i]
            re = re.contiguous().view(B, -1, num_spk)
        else:
            raise ValueError(
                'Invalid loss type: {}, the loss type must be "dpcl" or "mdc"'.format(
                    self._loss_type
                )
            )

        V2 = torch.matmul(torch.transpose(inf, 2, 1), inf).pow(2).sum(dim=(1, 2))
        Y2 = (
            torch.matmul(torch.transpose(re, 2, 1).float(), re.float())
            .pow(2)
            .sum(dim=(1, 2))
        )
        VY = torch.matmul(torch.transpose(inf, 2, 1), re.float()).pow(2).sum(dim=(1, 2))

        return V2 + Y2 - 2 * VY
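The last three terms expand the standard deep-clustering objective ||V V^T - Y Y^T||_F^2 as ||V^T V||_F^2 + ||Y^T Y||_F^2 - 2 ||V^T Y||_F^2, which avoids building the (T*F) x (T*F) affinity matrices. A quick numeric sanity check of that identity on tiny random tensors (illustration only, not part of the ESPnet test suite):

import torch

torch.manual_seed(0)
B, TF, D, S = 2, 50, 20, 2
V = torch.randn(B, TF, D, dtype=torch.float64)  # embeddings
Y = torch.randn(B, TF, S, dtype=torch.float64)  # targets (one-hot / simplex in the real loss)

# expanded form, as computed in FrequencyDomainDPCL.forward
expanded = (
    torch.matmul(V.transpose(2, 1), V).pow(2).sum(dim=(1, 2))
    + torch.matmul(Y.transpose(2, 1), Y).pow(2).sum(dim=(1, 2))
    - 2 * torch.matmul(V.transpose(2, 1), Y).pow(2).sum(dim=(1, 2))
)

# direct form: Frobenius norm of the affinity-matrix difference
direct = (V @ V.transpose(2, 1) - Y @ Y.transpose(2, 1)).pow(2).sum(dim=(1, 2))

print(torch.allclose(expanded, direct))  # True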
32 changes: 32 additions & 0 deletions espnet2/enh/loss/wrappers/dpcl_solver.py
@@ -0,0 +1,32 @@
import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper


class DPCLSolver(AbsLossWrapper):
    def __init__(self, criterion: AbsEnhLoss, weight=1.0):
        super().__init__()
        self.criterion = criterion
        self.weight = weight

    def forward(self, ref, inf, others={}):
        """A naive DPCL solver.

        Args:
            ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
            inf (List[torch.Tensor]): [(batch, ...), ...]
            others (dict): must contain "V", the embedding matrix from the separator
        Returns:
            loss (torch.Tensor): the DPCL loss averaged over the batch
            stats: dict, for collecting training status
            others: reserved
        """
        assert "V" in others

        loss = self.criterion(ref, others["V"]).mean()

        stats = dict()
        stats[self.criterion.name] = loss.detach()

        return loss.mean(), stats, {}
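A minimal usage sketch of the new criterion and wrapper together; the shapes below (batch 2, 63 frames, 129 bins, 40-dim embeddings, 2 speakers) simply follow the docstrings above, and in an actual run these objects are built from the criterions section of the YAML configs:

import torch

from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
from espnet2.enh.loss.wrappers.dpcl_solver import DPCLSolver

criterion = FrequencyDomainDPCL(loss_type="dpcl")
wrapper = DPCLSolver(criterion, weight=1.0)

B, T, F, D, n_spk = 2, 63, 129, 40, 2
ref = [torch.randn(B, T, F).abs() for _ in range(n_spk)]  # reference spectrograms
V = torch.randn(B, T * F, D)                              # separator embeddings
V = V / V.norm(dim=-1, keepdim=True)

# inf is unused by this wrapper; the embeddings travel in others["V"]
loss, stats, _ = wrapper(ref, None, {"V": V})
print(loss, stats)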
