Update code and add comments in separator
earthmanylf committed Mar 7, 2022
1 parent 5f86c11 commit 294373a
Showing 11 changed files with 26 additions and 8 deletions.
2 changes: 2 additions & 0 deletions espnet2/enh/loss/wrappers/dpcl_solver.py
@@ -14,6 +14,8 @@ def forward(self, ref, inf, others={}):
Args:
ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
inf (List[torch.Tensor]): [(batch, ...), ...]
others (List): other data included in this solver
e.g. "tf_embedding" learned embedding of all T-F bins (B, T * F, D)
Returns:
loss: (torch.Tensor): minimum loss with the best permutation
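As a rough sketch of the shapes documented above (the solver call at the end is a hypothetical placeholder; only the tensor shapes follow the docstring):

import torch

B, T, F, D, n_spk = 4, 100, 129, 40, 2  # illustrative sizes only

# Per-speaker reference and estimated features, each a list of (batch, ...) tensors.
ref = [torch.rand(B, T, F) for _ in range(n_spk)]
inf = [torch.rand(B, T, F) for _ in range(n_spk)]

# Extra data consumed by the solver: the learned embedding of all T-F bins, (B, T * F, D).
others = {"tf_embedding": torch.rand(B, T * F, D)}

# loss = dpcl_solver(ref, inf, others)  # hypothetical call; returns the minimum loss over permutations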
1 change: 1 addition & 0 deletions espnet2/enh/separator/conformer_separator.py
@@ -131,6 +131,7 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
4 changes: 2 additions & 2 deletions espnet2/enh/separator/dan_separator.py
@@ -82,8 +82,8 @@ def forward(
Args:
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
ilens (torch.Tensor): input lengths [Batch]
- additional (Dict or None): other data included in model,
-     e.g. additional["feature_ref"]: torch.Tensor(B, T, F)
+ additional (Dict or None): other data included in model
+     e.g. "feature_ref": list of reference spectra List[(B, T, F)]
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
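A minimal sketch of assembling the renamed "feature_ref" entry on the caller side (the separator call is hypothetical; only the shapes follow the updated docstring):

import torch

B, T, F, n_spk = 4, 100, 129, 2  # illustrative sizes only

feature_mix = torch.rand(B, T, F)              # encoded feature [B, T, F]
ilens = torch.full((B,), T, dtype=torch.long)  # input lengths [Batch]

# Reference spectra are now documented as a list of (B, T, F) tensors under "feature_ref".
additional = {"feature_ref": [torch.rand(B, T, F) for _ in range(n_spk)]}

# masked, ilens, others = dan_separator(feature_mix, ilens, additional)  # hypothetical call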
2 changes: 2 additions & 0 deletions espnet2/enh/separator/dccrn_separator.py
@@ -171,6 +171,8 @@ def forward(
Args:
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, F), ...]
7 changes: 4 additions & 3 deletions espnet2/enh/separator/dpcl_separator.py
@@ -84,12 +84,13 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, F]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
ilens (torch.Tensor): (B,)
others predicted data, e.g. tf_embedding: OrderedDict[
- 'tf_embedding': torch.Tensor(Batch, T * F, D),
+ 'tf_embedding': learned embedding of all T-F bins (B, T * F, D),
]
"""

@@ -112,8 +113,8 @@ def forward(
else:
# K-means for batch
centers = tf_embedding[:, : self._num_spk, :].detach()
- dist = torch.empty(B, T * F, self._num_spk).to(tf_embedding.device)
- last_label = torch.zeros(B, T * F).to(tf_embedding.device)
+ dist = torch.empty(B, T * F, self._num_spk, device=tf_embedding.device)
+ last_label = torch.zeros(B, T * F, device=tf_embedding.device)
while True:
for i in range(self._num_spk):
dist[:, :, i] = torch.sum(
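The buffer change above replaces torch.empty(...).to(device) with the device= keyword; a small standalone sketch of the difference (the .to() path first allocates on the default device and then copies when the target is a GPU, while device= allocates directly on the target):

import torch

tf_embedding = torch.rand(2, 100 * 129, 40)  # stand-in for the embedding tensor

# Old style: allocate on the default device, then move.
dist_old = torch.empty(2, 100 * 129, 2).to(tf_embedding.device)

# New style, as in the diff: allocate directly on the target device, avoiding the extra copy.
dist_new = torch.empty(2, 100 * 129, 2, device=tf_embedding.device)

assert dist_old.device == dist_new.device and dist_old.shape == dist_new.shape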
3 changes: 2 additions & 1 deletion espnet2/enh/separator/dprnn_separator.py
@@ -83,7 +83,8 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
ilens (torch.Tensor): (B,)
3 changes: 2 additions & 1 deletion espnet2/enh/separator/neural_beamformer.py
@@ -139,7 +139,8 @@ def forward(
mixed speech [Batch, Frames, Channel, Freq]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
enhanced speech (single-channel): List[torch.complex64/ComplexTensor]
output lengths
1 change: 1 addition & 0 deletions espnet2/enh/separator/rnn_separator.py
@@ -79,6 +79,7 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
9 changes: 8 additions & 1 deletion espnet2/enh/separator/skim_separator.py
@@ -1,5 +1,7 @@
from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

@@ -80,13 +82,18 @@ def __init__(
}[nonlinear]

def forward(
- self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
+ self,
+ input: Union[torch.Tensor, ComplexTensor],
+ ilens: torch.Tensor,
+ additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
"""Forward.
Args:
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
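The new signature gives SkiMSeparator the same three-argument forward() as the other separators; a hypothetical minimal module (not ESPnet code) showing the shared interface the commit standardizes on:

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch


class DummySeparator(torch.nn.Module):
    """Hypothetical separator illustrating the shared forward() interface."""

    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,  # accepted for interface parity; unused here
    ) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]:
        masked = [input]        # a single trivial "source": just echo the input
        others = OrderedDict()  # no extra predicted data in this sketch
        return masked, ilens, others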
1 change: 1 addition & 0 deletions espnet2/enh/separator/tcn_separator.py
@@ -78,6 +78,7 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
1 change: 1 addition & 0 deletions espnet2/enh/separator/transformer_separator.py
@@ -118,6 +118,7 @@ def forward(
input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
