Skip to content

Commit

Permalink
Fixed format in some files
Browse files Browse the repository at this point in the history
  • Loading branch information
earthmanylf committed Mar 7, 2022
1 parent 294373a commit 3e6167c
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 12 deletions.
14 changes: 11 additions & 3 deletions espnet2/enh/loss/criterions/tf_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def forward(self, ref, inf) -> torch.Tensor:
mask = reduce(lambda x, y: x * y, flags)
mask = mask.int() * i
r += mask
r = r.contiguous().view(-1,).long()
r = r.contiguous().flatten().long()
re = F.one_hot(r, num_classes=num_spk)
re = re.contiguous().view(B, -1, num_spk)
elif self._loss_type == "mdc":
Expand All @@ -255,7 +255,11 @@ def forward(self, ref, inf) -> torch.Tensor:
)

re = torch.zeros(
ref[0].shape[0], ref[0].shape[1], ref[0].shape[2], num_spk, device=inf.device
ref[0].shape[0],
ref[0].shape[1],
ref[0].shape[2],
num_spk,
device=inf.device,
)
for i in range(num_spk):
flags = [abs_ref[i] >= n for n in abs_ref]
Expand All @@ -270,7 +274,11 @@ def forward(self, ref, inf) -> torch.Tensor:
)

V2 = torch.matmul(torch.transpose(inf, 2, 1), inf).pow(2).sum(dim=(1, 2))
Y2 = torch.matmul(torch.transpose(re, 2, 1).float(), re.float()).pow(2).sum(dim=(1, 2))
Y2 = (
torch.matmul(torch.transpose(re, 2, 1).float(), re.float())
.pow(2)
.sum(dim=(1, 2))
)
VY = torch.matmul(torch.transpose(inf, 2, 1), re.float()).pow(2).sum(dim=(1, 2))

return V2 + Y2 - 2 * VY
3 changes: 1 addition & 2 deletions espnet2/enh/separator/dan_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def forward(
'mask_spkn': torch.Tensor(Batch, Frames, Freq),
]
"""

# if complex spectrum,
if isinstance(input, ComplexTensor):
feature = abs(input)
Expand All @@ -122,7 +121,7 @@ def forward(
Y = reduce(lambda x, y: x * y, flags)
Y = Y.int() * i
Y_t += Y
Y_t = Y_t.contiguous().view(-1,).long()
Y_t = Y_t.contiguous().flatten().long()
Y = Fun.one_hot(Y_t, num_classes=self._num_spk)
Y = Y.contiguous().view(B, -1, self._num_spk).float()

Expand Down
4 changes: 2 additions & 2 deletions espnet2/enh/separator/dccrn_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def __init__(
self.flatten_parameters()

def forward(
self,
input: Union[torch.Tensor, ComplexTensor],
self,
input: Union[torch.Tensor, ComplexTensor],
ilens: torch.Tensor,
additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
Expand Down
1 change: 0 additions & 1 deletion espnet2/enh/separator/dpcl_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def forward(
'tf_embedding': learned embedding of all T-F bins (B, T * F, D),
]
"""

# if complex spectrum,
if isinstance(input, ComplexTensor):
feature = abs(input)
Expand Down
2 changes: 1 addition & 1 deletion espnet2/enh/separator/dprnn_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def forward(
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
ilens (torch.Tensor): (B,)
Expand Down
2 changes: 1 addition & 1 deletion espnet2/enh/separator/neural_beamformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def forward(
ilens (torch.Tensor): input lengths [Batch]
additional (Dict or None): other data included in model
NOTE: not used in this model
Returns:
enhanced speech (single-channel): List[torch.complex64/ComplexTensor]
output lengths
Expand Down
4 changes: 2 additions & 2 deletions espnet2/enh/separator/skim_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def __init__(
}[nonlinear]

def forward(
self,
input: Union[torch.Tensor, ComplexTensor],
self,
input: Union[torch.Tensor, ComplexTensor],
ilens: torch.Tensor,
additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
Expand Down

0 comments on commit 3e6167c

Please sign in to comment.