From ffe824a6df26169bc408c6064f60f8b97c105720 Mon Sep 17 00:00:00 2001 From: Changsheng Quan Date: Mon, 28 Feb 2022 20:46:42 +0800 Subject: [PATCH] Improved shape checking of `permutation_invariant_training` (#864) * fix * update change log * pep 8 Co-authored-by: quancs --- tests/audio/test_pit.py | 5 ++++- torchmetrics/functional/audio/pit.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/audio/test_pit.py b/tests/audio/test_pit.py index c668f3a77b7..6bc8c76a832 100644 --- a/tests/audio/test_pit.py +++ b/tests/audio/test_pit.py @@ -171,7 +171,10 @@ def test_pit_half_gpu(self, preds, target, sk_metric, metric_func, eval_func): def test_error_on_different_shape() -> None: metric = PermutationInvariantTraining(signal_noise_ratio, "max") - with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + with pytest.raises( + RuntimeError, + match="Predictions and targets are expected to have the same shape at the batch and speaker dimensions", + ): metric(torch.randn(3, 3, 10), torch.randn(3, 2, 10)) diff --git a/torchmetrics/functional/audio/pit.py b/torchmetrics/functional/audio/pit.py index 69eef5b19e7..d0e3eb3d44d 100644 --- a/torchmetrics/functional/audio/pit.py +++ b/torchmetrics/functional/audio/pit.py @@ -18,7 +18,6 @@ import torch from torch import Tensor -from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _SCIPY_AVAILABLE # _ps_dict: cache of permutations @@ -145,7 +144,10 @@ def permutation_invariant_training( Reference: [1] `Permutation Invariant Training of Deep Models`_ """ - _check_same_shape(preds, target) + if preds.shape[0:2] != target.shape[0:2]: + raise RuntimeError( + "Predictions and targets are expected to have the same shape at the batch and speaker dimensions" + ) if eval_func not in ["max", "min"]: raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}') if target.ndim < 2: