Commit c54d9a4

add subs to the abs_separator.py

earthmanylf committed Feb 25, 2022
1 parent c1d9be5 commit c54d9a4
Showing 18 changed files with 151 additions and 73 deletions.
8 changes: 5 additions & 3 deletions espnet2/enh/espnet_model.py
@@ -132,14 +132,16 @@ def forward(
         # for data-parallel
         speech_ref = speech_ref[..., : speech_lengths.max()]
         speech_ref = speech_ref.unbind(dim=1)
-        sep_others = {}
-        sep_others["feature_ref"] = [self.encoder(r, speech_lengths)[0] for r in speech_ref]
+        additional = {}
+        additional["feature_ref"] = [
+            self.encoder(r, speech_lengths)[0] for r in speech_ref
+        ]
 
         speech_mix = speech_mix[:, : speech_lengths.max()]
 
         # model forward
         feature_mix, flens = self.encoder(speech_mix, speech_lengths)
-        feature_pre, flens, others = self.separator(feature_mix, flens, sep_others)
+        feature_pre, flens, others = self.separator(feature_mix, flens, additional)
         if feature_pre is not None:
             speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre]
         else:
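The change above threads encoded reference signals to the separator through the renamed `additional` dict. A minimal sketch of the calling convention, using hypothetical toy_encoder/toy_separator stand-ins rather than the actual espnet classes:

    import torch

    def toy_encoder(wav, ilens):
        # stand-in for the encoder: (B, T) waveform -> (B, T', F) feature
        return wav.unfold(-1, 16, 16), ilens // 16

    def toy_separator(feature_mix, flens, additional=None):
        # a real separator may read additional["feature_ref"] during training
        mask = torch.sigmoid(feature_mix)
        return [feature_mix * mask], flens, {}

    speech_mix = torch.randn(2, 1600)        # (B, T)
    speech_ref = [torch.randn(2, 1600)] * 2  # one waveform per speaker
    lengths = torch.tensor([1600, 1600])

    additional = {"feature_ref": [toy_encoder(r, lengths)[0] for r in speech_ref]}
    feature_mix, flens = toy_encoder(speech_mix, lengths)
    feature_pre, flens, others = toy_separator(feature_mix, flens, additional)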
25 changes: 15 additions & 10 deletions espnet2/enh/loss/criterions/tf_domain.py
@@ -2,10 +2,10 @@
 from abc import abstractmethod
 from distutils.version import LooseVersion
 from functools import reduce
+import math
 
 import torch
 import torch.nn.functional as F
-import math
 
 from espnet2.enh.layers.complex_utils import is_complex
 from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
@@ -212,19 +212,19 @@ def forward(self, ref, inf) -> torch.Tensor:
         """time-frequency Deep Clustering loss.
 
         References:
-            [1] Deep clustering: Discriminative embeddings for segmentation and separation;
-                John R. Hershey. et al., 2016;
+            [1] Deep clustering: Discriminative embeddings for segmentation and
+                separation; John R. Hershey. et al., 2016;
                 https://ieeexplore.ieee.org/document/7471631
-            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding Vectors Based on Regular Simplex;
-                Tanaka, K. et al., 2021;
+            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding
+                Vectors Based on Regular Simplex; Tanaka, K. et al., 2021;
                 https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html
 
         Args:
             ref: List[(Batch, T, F) * spks]
             inf: (Batch, T*F, D)
         Returns:
             loss: (Batch,)
-        """
+        """  # noqa: E501
         assert len(ref) > 0
         num_spk = len(ref)
 
@@ -237,7 +237,13 @@ def forward(self, ref, inf) -> torch.Tensor:
                 mask = reduce(lambda x, y: x * y, flags)
                 mask = mask.int() * i
                 r += mask
-            r = r.contiguous().view(-1,).long()
+            r = (
+                r.contiguous()
+                .view(
+                    -1,
+                )
+                .long()
+            )
             re = F.one_hot(r, num_classes=num_spk)
             re = re.contiguous().view(B, -1, num_spk)
         elif self._loss_type == "mdc":
@@ -263,9 +269,8 @@ def forward(self, ref, inf) -> torch.Tensor:
             re = re.contiguous().view(B, -1, num_spk)
         else:
             raise ValueError(
-                'Invalid loss type error: {}, the loss type must be "dpcl" or "mdc"'.format(
-                    self._loss_type
-                )
+                f"Invalid loss type error: {self._loss_type}, "
+                'the loss type must be "dpcl" or "mdc"'
             )
 
         V2 = torch.matmul(torch.transpose(inf, 2, 1), inf).pow(2).sum(dim=(1, 2))
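The V2 line above is the first term of the expanded deep clustering objective ||VV^T - YY^T||_F^2 = ||V^T V||_F^2 - 2 ||V^T Y||_F^2 + ||Y^T Y||_F^2, which avoids materialising the (T*F, T*F) affinity matrices. A minimal sketch of the full expansion, assuming embeddings V of shape (B, T*F, D) and one-hot assignments Y of shape (B, T*F, spks); the actual criterion additionally builds Y from the reference magnitudes as in the hunks above:

    import torch

    def dpcl_loss(V: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # ||VV^T - YY^T||_F^2 computed via three small Gram matrices
        VtV = torch.matmul(V.transpose(2, 1), V).pow(2).sum(dim=(1, 2))
        VtY = torch.matmul(V.transpose(2, 1), Y).pow(2).sum(dim=(1, 2))
        YtY = torch.matmul(Y.transpose(2, 1), Y).pow(2).sum(dim=(1, 2))
        return VtV - 2 * VtY + YtY  # one loss value per batch element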
2 changes: 0 additions & 2 deletions espnet2/enh/loss/wrappers/dpcl_solver.py
@@ -1,5 +1,3 @@
-import torch
-
 from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
 from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper
 
3 changes: 3 additions & 0 deletions espnet2/enh/separator/abs_separator.py
@@ -1,6 +1,8 @@
 from abc import ABC
 from abc import abstractmethod
 from collections import OrderedDict
+from typing import Dict
+from typing import Optional
 from typing import Tuple
 
 import torch
@@ -12,6 +14,7 @@ def forward(
         self,
         input: torch.Tensor,
         ilens: torch.Tensor,
+        additional: Optional[Dict] = None,
     ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]:
 
         raise NotImplementedError
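This is the interface change the commit is named after: every separator now accepts an optional `additional` dict, and the remaining files apply the same signature to the concrete separators. A minimal conforming subclass, shown as an illustrative sketch (IdentitySeparator is hypothetical, not part of espnet, and assumes AbsSeparator's abstract num_spk property):

    from collections import OrderedDict
    from typing import Dict, Optional, Tuple

    import torch

    from espnet2.enh.separator.abs_separator import AbsSeparator

    class IdentitySeparator(AbsSeparator):
        def __init__(self, num_spk: int = 1):
            super().__init__()
            self._num_spk = num_spk

        def forward(
            self,
            input: torch.Tensor,
            ilens: torch.Tensor,
            additional: Optional[Dict] = None,  # accepted for interface parity
        ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]:
            # pass the mixture through unchanged, once per "speaker"
            return (input,) * self._num_spk, ilens, OrderedDict()

        @property
        def num_spk(self):
            return self._num_spk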
10 changes: 9 additions & 1 deletion espnet2/enh/separator/asteroid_models.py
@@ -1,4 +1,6 @@
 from collections import OrderedDict
+from typing import Dict
+from typing import Optional
 from typing import Tuple
 import warnings
@@ -66,12 +68,18 @@ def __init__(
         if loss_type != "si_snr":
             raise ValueError("Unsupported loss type: %s" % loss_type)
 
-    def forward(self, input: torch.Tensor, ilens: torch.Tensor = None):
+    def forward(
+        self,
+        input: torch.Tensor,
+        ilens: torch.Tensor = None,
+        additional: Optional[Dict] = None,
+    ):
         """Whole forward of asteroid models.
 
         Args:
             input (torch.Tensor): Raw Waveforms [B, T]
             ilens (torch.Tensor): input lengths [B]
+            additional (Dict or None): other data included in model
 
         Returns:
             estimated Waveforms(List[Union(torch.Tensor]): [(B, T), ...]
8 changes: 7 additions & 1 deletion espnet2/enh/separator/conformer_separator.py
@@ -1,6 +1,8 @@
 from collections import OrderedDict
 from distutils.version import LooseVersion
+from typing import Dict
 from typing import List
+from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -118,13 +120,17 @@ def __init__(
         }[nonlinear]
 
     def forward(
-        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
+        self,
+        input: Union[torch.Tensor, ComplexTensor],
+        ilens: torch.Tensor,
+        additional: Optional[Dict] = None,
     ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
         """Forward.
 
         Args:
             input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
             ilens (torch.Tensor): input lengths [Batch]
+            additional (Dict or None): other data included in model
 
         Returns:
             masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
55 changes: 34 additions & 21 deletions espnet2/enh/separator/dan_separator.py
@@ -1,8 +1,10 @@
 from collections import OrderedDict
-from typing import Dict, List
+from functools import reduce
+from typing import Dict
+from typing import List
+from typing import Optional
 from typing import Tuple
 from typing import Union
-from functools import reduce
 
 import torch
 import torch.nn.functional as Fun
@@ -25,7 +27,7 @@ def __init__(
         dropout: float = 0.0,
     ):
         """Deep Attractor Network Separator
-        
+
         Reference:
             DEEP ATTRACTOR NETWORK FOR SINGLE-MICROPHONE SPEAKER SEPARATION;
             Zhuo Chen. et al., 2017;
@@ -40,7 +42,7 @@ def __init__(
                 select from 'relu', 'tanh', 'sigmoid'
             layer: int, number of stacked RNN layers. Default is 3.
             unit: int, dimension of the hidden state.
-            emb_D: int, dimension of the attribute vector for one tf-bin. 
+            emb_D: int, dimension of the attribute vector for one tf-bin.
             dropout: float, dropout ratio. Default is 0.
         """
         super().__init__()
@@ -70,14 +72,17 @@ def __init__(
         self.D = emb_D
 
     def forward(
-        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor, o=None
+        self,
+        input: Union[torch.Tensor, ComplexTensor],
+        ilens: torch.Tensor,
+        additional: Optional[Dict] = None,
     ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
         """Forward.
 
         Args:
             input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
             ilens (torch.Tensor): input lengths [Batch]
-            origin List[ComplexTensor(B, T, [C,] F), ...]: Origin data
+            additional (Dict or None): other data included in model
 
         Returns:
             masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
@@ -100,19 +105,25 @@ def forward(
         # x:(B, T, F*D)
         x = self.nonlinear(x)
         # V:(B, T*F, D)
-        V = x.contiguous().view(B, T*F, -1)
+        V = x.contiguous().view(B, T * F, -1)
 
         # Compute the attractors
         if self.training:
-            assert o is not None and "feature_ref" in o
-            origin = o["feature_ref"]
+            assert additional is not None and "feature_ref" in additional
+            origin = additional["feature_ref"]
             Y_t = torch.zeros(B, T, F, device=origin[0].device)
             for i in range(self._num_spk):
                 flags = [abs(origin[i]) >= abs(n) for n in origin]
                 Y = reduce(lambda x, y: x * y, flags)
                 Y = Y.int() * i
                 Y_t += Y
-            Y_t = Y_t.contiguous().view(-1,).long()
+            Y_t = (
+                Y_t.contiguous()
+                .view(
+                    -1,
+                )
+                .long()
+            )
             Y = Fun.one_hot(Y_t, num_classes=self._num_spk)
             Y = Y.contiguous().view(B, -1, self._num_spk).float()
 
@@ -122,31 +133,33 @@ def forward(
             sum_y = torch.sum(Y, 1, keepdim=True).expand_as(v_y)
             # attractor:(B, D, spks)
             attractor = v_y / (sum_y + 1e-8)
-        else: 
+        else:
             # K-means for batch
-            centers = V[:,:self._num_spk,:].detach()
-            dist = torch.empty(B, T*F, self._num_spk).to(V.device)
-            last_label = torch.zeros(B, T*F).to(V.device)
+            centers = V[:, : self._num_spk, :].detach()
+            dist = torch.empty(B, T * F, self._num_spk).to(V.device)
+            last_label = torch.zeros(B, T * F).to(V.device)
             while True:
                 for i in range(self._num_spk):
-                    dist[:,:,i] = torch.sum((V-centers[:,i,:].unsqueeze(1))**2, dim=2)
+                    dist[:, :, i] = torch.sum(
+                        (V - centers[:, i, :].unsqueeze(1)) ** 2, dim=2
+                    )
                 label = dist.argmin(dim=2)
                 if torch.sum(label != last_label) == 0:
                     break
                 last_label = label
                 for b in range(B):
                     for i in range(self._num_spk):
-                        centers[b,i] = V[b, label[b]==i].mean(dim=0)
-            attractor = centers.permute(0,2,1)
-        # calculate the distance between embeddings and attractors and generate the masks
+                        centers[b, i] = V[b, label[b] == i].mean(dim=0)
+            attractor = centers.permute(0, 2, 1)
+
+        # calculate the distance between embeddings and attractors
         # dist:(B, T*F, spks)
         dist = torch.bmm(V, attractor)
-        masks = torch.softmax(dist,dim=2)
+        masks = torch.softmax(dist, dim=2)
         masks = masks.contiguous().view(B, T, F, self._num_spk).unbind(dim=3)
 
         masked = [input * m for m in masks]
 
         others = OrderedDict(
             zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
         )
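The inference branch above estimates attractors with a batched K-means over the T-F embeddings. A standalone sketch of that procedure, assuming V of shape (B, N, D) with N = T*F; kmeans_attractors is a hypothetical helper, and unlike the class method it guards against empty clusters:

    import torch

    def kmeans_attractors(V: torch.Tensor, num_spk: int) -> torch.Tensor:
        # V: (B, N, D) embeddings -> attractors: (B, D, num_spk)
        B, N, _ = V.shape
        centers = V[:, :num_spk, :].clone()  # init centers from the first bins
        last_label = torch.zeros(B, N, dtype=torch.long, device=V.device)
        while True:
            # squared distance of every embedding to every center: (B, N, num_spk)
            dist = ((V.unsqueeze(2) - centers.unsqueeze(1)) ** 2).sum(dim=3)
            label = dist.argmin(dim=2)
            if (label != last_label).sum() == 0:  # converged: no label moved
                break
            last_label = label
            for b in range(B):
                for i in range(num_spk):
                    members = V[b, label[b] == i]
                    if members.numel() > 0:  # guard against empty clusters
                        centers[b, i] = members.mean(dim=0)
        return centers.permute(0, 2, 1)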
28 changes: 18 additions & 10 deletions espnet2/enh/separator/dpcl_separator.py
@@ -1,5 +1,7 @@
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Dict
+from typing import List
+from typing import Optional
 from typing import Tuple
 from typing import Union
 
@@ -22,14 +24,14 @@ def __init__(
         emb_D: int = 40,
         dropout: float = 0.0,
     ):
-        """Deep Clustering Separator
+        """Deep Clustering Separator.
 
         References:
-            [1] Deep clustering: Discriminative embeddings for segmentation and separation;
-                John R. Hershey. et al., 2016;
+            [1] Deep clustering: Discriminative embeddings for segmentation and
+                separation; John R. Hershey. et al., 2016;
                 https://ieeexplore.ieee.org/document/7471631
-            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding Vectors Based on Regular Simplex;
-                Tanaka, K. et al., 2021;
+            [2] Manifold-Aware Deep Clustering: Maximizing Angles Between Embedding
+                Vectors Based on Regular Simplex; Tanaka, K. et al., 2021;
                 https://www.isca-speech.org/archive/interspeech_2021/tanaka21_interspeech.html
 
         Args:
@@ -41,9 +43,9 @@ def __init__(
                 select from 'relu', 'tanh', 'sigmoid'
             layer: int, number of stacked RNN layers. Default is 3.
             unit: int, dimension of the hidden state.
-            emb_D: int, dimension of the feature vector for a tf-bin. 
+            emb_D: int, dimension of the feature vector for a tf-bin.
             dropout: float, dropout ratio. Default is 0.
-        """
+        """  # noqa: E501
         super().__init__()
 
         self._num_spk = num_spk
@@ -71,13 +73,17 @@ def __init__(
         self.D = emb_D
 
     def forward(
-        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor, o=None
+        self,
+        input: Union[torch.Tensor, ComplexTensor],
+        ilens: torch.Tensor,
+        additional: Optional[Dict] = None,
     ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
         """Forward.
 
         Args:
             input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
             ilens (torch.Tensor): input lengths [Batch]
+            additional (Dict or None): other data included in model
 
         Returns:
             masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
@@ -125,7 +131,9 @@ def forward(
         for i in range(self._num_spk):
             masked.append(input * (label == i))
 
-        others = OrderedDict({"V": V},)
+        others = OrderedDict(
+            {"V": V},
+        )
 
         return masked, ilens, others
 
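At inference this separator clusters the embeddings it returns in others["V"] and binarises the cluster labels into per-speaker masks, as the `masked.append(input * (label == i))` line above shows. A tiny self-contained sketch of that last step, with label_masks as a hypothetical helper and random stand-in data:

    import torch

    def label_masks(feature: torch.Tensor, label: torch.Tensor, num_spk: int):
        # feature: (B, T, F) mixture feature; label: (B, T, F) integer cluster ids
        # one binary mask per speaker; multiplying separates the mixture
        return [feature * (label == i) for i in range(num_spk)]

    feature = torch.randn(2, 100, 129)
    label = torch.randint(0, 2, (2, 100, 129))
    separated = label_masks(feature, label, num_spk=2)  # two (B, T, F) estimates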
8 changes: 7 additions & 1 deletion espnet2/enh/separator/dprnn_separator.py
@@ -1,6 +1,8 @@
 from collections import OrderedDict
 from distutils.version import LooseVersion
+from typing import Dict
 from typing import List
+from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -70,13 +72,17 @@ def __init__(
         }[nonlinear]
 
     def forward(
-        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
+        self,
+        input: Union[torch.Tensor, ComplexTensor],
+        ilens: torch.Tensor,
+        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
         """Forward.
 
         Args:
             input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
             ilens (torch.Tensor): input lengths [Batch]
+            additional (Dict or None): other data included in model
 
         Returns:
             masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]