From 5ea4e087a311ab7c798950e68ae92e10b1bb41d8 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Sat, 7 May 2022 12:05:49 +0800 Subject: [PATCH 1/2] Fix a bug in stats aggregation when PITSolver is used --- espnet2/enh/loss/wrappers/pit_solver.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/espnet2/enh/loss/wrappers/pit_solver.py b/espnet2/enh/loss/wrappers/pit_solver.py index 9cb810f5c9b..8b822cfeede 100644 --- a/espnet2/enh/loss/wrappers/pit_solver.py +++ b/espnet2/enh/loss/wrappers/pit_solver.py @@ -73,12 +73,15 @@ def pair_loss(permutation): ) # remove stats from unused permutations for k, v in stats.items(): - # (B, len(all_permutations), ...) + # (B, num_spk * len(all_permutations), ...) new_v = torch.stack(v, dim=1) + B, L, *rest = new_v.shape + assert L == num_spk * len(all_permutations), (L, num_spk) + new_v = new_v.view(B, L // num_spk, num_spk, *rest).mean(2) if new_v.dim() > 2: - shapes = [1 for _ in range(new_v.dim() - 2)] + shapes = [1 for _ in rest] perm0 = perm_.view(perm_.shape[0], 1, *shapes).expand( - -1, -1, *new_v.shape[2:] + -1, -1, *rest ) else: perm0 = perm_.unsqueeze(1) From 9e8e753154f5f71c9cb26217483427adb278759c Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Sat, 7 May 2022 13:16:35 +0800 Subject: [PATCH 2/2] Apply black --- espnet2/enh/loss/wrappers/pit_solver.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/espnet2/enh/loss/wrappers/pit_solver.py b/espnet2/enh/loss/wrappers/pit_solver.py index 8b822cfeede..eab7f5e97a4 100644 --- a/espnet2/enh/loss/wrappers/pit_solver.py +++ b/espnet2/enh/loss/wrappers/pit_solver.py @@ -80,9 +80,7 @@ def pair_loss(permutation): new_v = new_v.view(B, L // num_spk, num_spk, *rest).mean(2) if new_v.dim() > 2: shapes = [1 for _ in rest] - perm0 = perm_.view(perm_.shape[0], 1, *shapes).expand( - -1, -1, *rest - ) + perm0 = perm_.view(perm_.shape[0], 1, *shapes).expand(-1, -1, *rest) else: perm0 = perm_.unsqueeze(1) stats[k] = new_v.gather(1, perm0.to(device=new_v.device)).unbind(1)